From 62db346b49f8ccda1352af31466c81e1327c23bb Mon Sep 17 00:00:00 2001 From: Aayush Goel <81844215+Aayush-Goel-04@users.noreply.github.com> Date: Thu, 6 Jul 2023 05:28:13 +0530 Subject: [PATCH] Style , mypy checks --- capa/render/proto/__init__.py | 3 ++- capa/rules/cache.py | 1 + scripts/cache-ruleset.py | 6 ++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/capa/render/proto/__init__.py b/capa/render/proto/__init__.py index a5a3f04d..9aae03ca 100644 --- a/capa/render/proto/__init__.py +++ b/capa/render/proto/__init__.py @@ -29,6 +29,7 @@ import json import argparse import datetime from typing import Any, Dict, Union +from pathlib import Path import google.protobuf.json_format from google.protobuf.json_format import MessageToJson @@ -500,7 +501,7 @@ def metadata_from_pb2(meta: capa_pb2.Metadata) -> rd.Metadata: arch=meta.analysis.arch, os=meta.analysis.os, extractor=meta.analysis.extractor, - rules=tuple(meta.analysis.rules), + rules=tuple(Path(r) for r in meta.analysis.rules), base_address=addr_from_pb2(meta.analysis.base_address), layout=rd.Layout( functions=tuple( diff --git a/capa/rules/cache.py b/capa/rules/cache.py index 2d49c407..9ff9fbb7 100644 --- a/capa/rules/cache.py +++ b/capa/rules/cache.py @@ -1,3 +1,4 @@ +import os import sys import zlib import pickle diff --git a/scripts/cache-ruleset.py b/scripts/cache-ruleset.py index a2a49bdb..76dd3fd8 100644 --- a/scripts/cache-ruleset.py +++ b/scripts/cache-ruleset.py @@ -20,6 +20,7 @@ import sys import time import logging import argparse +from pathlib import Path import capa.main import capa.rules @@ -48,8 +49,9 @@ def main(argv=None): logging.getLogger("capa").setLevel(logging.ERROR) try: - os.makedirs(args.cache, exist_ok=True) - rules = capa.main.get_rules(args.rules, cache_dir=args.cache) + cache_dir = Path(args.cache) + cache_dir.mkdir(parents=True, exist_ok=True) + rules = capa.main.get_rules(args.rules, cache_dir) logger.info("successfully loaded %s rules", len(rules)) except (IOError, capa.rules.InvalidRule, capa.rules.InvalidRuleSet) as e: logger.error("%s", str(e))