From b172f9a3544a05dda1348e71b284cc86d549a4ea Mon Sep 17 00:00:00 2001 From: Yacine Elhamer Date: Mon, 26 Jun 2023 22:46:27 +0100 Subject: [PATCH] FeatureExtractor alias: fix mypy typing issues by adding ininstance-based assert statements --- capa/features/freeze/__init__.py | 11 ++++++----- scripts/profile-time.py | 5 +++-- scripts/show-capabilities-by-function.py | 4 ++-- scripts/show-features.py | 8 ++++---- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/capa/features/freeze/__init__.py b/capa/features/freeze/__init__.py index e6ed9fe1..b29c1bb0 100644 --- a/capa/features/freeze/__init__.py +++ b/capa/features/freeze/__init__.py @@ -23,9 +23,9 @@ import capa.features.insn import capa.features.common import capa.features.address import capa.features.basicblock -import capa.features.extractors.base_extractor from capa.helpers import assert_never from capa.features.freeze.features import Feature, feature_from_capa +from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor logger = logging.getLogger(__name__) @@ -226,7 +226,7 @@ class Freeze(BaseModel): allow_population_by_field_name = True -def dumps(extractor: capa.features.extractors.base_extractor.StaticFeatureExtractor) -> str: +def dumps(extractor: StaticFeatureExtractor) -> str: """ serialize the given extractor to a string """ @@ -327,7 +327,7 @@ def dumps(extractor: capa.features.extractors.base_extractor.StaticFeatureExtrac return freeze.json() -def loads(s: str) -> capa.features.extractors.base_extractor.StaticFeatureExtractor: +def loads(s: str) -> StaticFeatureExtractor: """deserialize a set of features (as a NullFeatureExtractor) from a string.""" import capa.features.extractors.null as null @@ -363,8 +363,9 @@ def loads(s: str) -> capa.features.extractors.base_extractor.StaticFeatureExtrac MAGIC = "capa0000".encode("ascii") -def dump(extractor: capa.features.extractors.base_extractor.StaticFeatureExtractor) -> bytes: +def dump(extractor: FeatureExtractor) -> bytes: """serialize the given extractor to a byte array.""" + assert isinstance(extractor, StaticFeatureExtractor) return MAGIC + zlib.compress(dumps(extractor).encode("utf-8")) @@ -372,7 +373,7 @@ def is_freeze(buf: bytes) -> bool: return buf[: len(MAGIC)] == MAGIC -def load(buf: bytes) -> capa.features.extractors.base_extractor.StaticFeatureExtractor: +def load(buf: bytes) -> StaticFeatureExtractor: """deserialize a set of features (as a NullFeatureExtractor) from a byte array.""" if not is_freeze(buf): raise ValueError("missing magic header") diff --git a/scripts/profile-time.py b/scripts/profile-time.py index 2566a0fe..32aa31f7 100644 --- a/scripts/profile-time.py +++ b/scripts/profile-time.py @@ -46,7 +46,7 @@ import capa.helpers import capa.features import capa.features.common import capa.features.freeze -from capa.features.extractors.base_extractor import StaticFeatureExtractor +from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor logger = logging.getLogger("capa.profile") @@ -104,13 +104,14 @@ def main(argv=None): args.format == capa.features.common.FORMAT_AUTO and capa.features.freeze.is_freeze(taste) ): with open(args.sample, "rb") as f: - extractor = capa.features.freeze.load(f.read()) + extractor: FeatureExtractor = capa.features.freeze.load(f.read()) assert isinstance(extractor, StaticFeatureExtractor) else: extractor = capa.main.get_extractor( args.sample, args.format, args.os, capa.main.BACKEND_VIV, sig_paths, should_save_workspace=False ) + assert isinstance(extractor, StaticFeatureExtractor) with tqdm.tqdm(total=args.number * args.repeat) as pbar: def do_iteration(): diff --git a/scripts/show-capabilities-by-function.py b/scripts/show-capabilities-by-function.py index 7be4b99f..c5bfd571 100644 --- a/scripts/show-capabilities-by-function.py +++ b/scripts/show-capabilities-by-function.py @@ -70,7 +70,7 @@ import capa.render.result_document as rd from capa.helpers import get_file_taste from capa.features.common import FORMAT_AUTO from capa.features.freeze import Address -from capa.features.extractors.base_extractor import StaticFeatureExtractor +from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor logger = logging.getLogger("capa.show-capabilities-by-function") @@ -161,7 +161,7 @@ def main(argv=None): if (args.format == "freeze") or (args.format == FORMAT_AUTO and capa.features.freeze.is_freeze(taste)): format_ = "freeze" with open(args.sample, "rb") as f: - extractor = capa.features.freeze.load(f.read()) + extractor: FeatureExtractor = capa.features.freeze.load(f.read()) else: format_ = args.format should_save_workspace = os.environ.get("CAPA_SAVE_WORKSPACE") not in ("0", "no", "NO", "n", None) diff --git a/scripts/show-features.py b/scripts/show-features.py index 583f757e..023701bb 100644 --- a/scripts/show-features.py +++ b/scripts/show-features.py @@ -80,8 +80,8 @@ import capa.render.verbose as v import capa.features.common import capa.features.freeze import capa.features.address -import capa.features.extractors.base_extractor from capa.helpers import log_unsupported_runtime_error +from capa.features.extractors.base_extractor import FeatureExtractor, StaticFeatureExtractor logger = logging.getLogger("capa.show-features") @@ -117,14 +117,13 @@ def main(argv=None): args.format == capa.features.common.FORMAT_AUTO and capa.features.freeze.is_freeze(taste) ): with open(args.sample, "rb") as f: - extractor = capa.features.freeze.load(f.read()) + extractor: FeatureExtractor = capa.features.freeze.load(f.read()) else: should_save_workspace = os.environ.get("CAPA_SAVE_WORKSPACE") not in ("0", "no", "NO", "n", None) try: extractor = capa.main.get_extractor( args.sample, args.format, args.os, args.backend, sig_paths, should_save_workspace ) - assert isinstance(extractor, capa.features.extractors.base_extractor.StaticFeatureExtractor) except capa.exceptions.UnsupportedFormatError: capa.helpers.log_unsupported_format_error() return -1 @@ -132,6 +131,7 @@ def main(argv=None): log_unsupported_runtime_error() return -1 + assert isinstance(extractor, StaticFeatureExtractor) for feature, addr in extractor.extract_global_features(): print(f"global: {format_address(addr)}: {feature}") @@ -190,7 +190,7 @@ def ida_main(): return 0 -def print_features(functions, extractor: capa.features.extractors.base_extractor.FeatureExtractor): +def print_features(functions, extractor: StaticFeatureExtractor): for f in functions: if extractor.is_library_function(f.address): function_name = extractor.get_function_name(f.address)