diff --git a/capa/engine.py b/capa/engine.py index ee103866..649d0367 100644 --- a/capa/engine.py +++ b/capa/engine.py @@ -102,14 +102,14 @@ class And(Statement): super().__init__(description=description) self.children = children - def evaluate(self, ctx, short_circuit=True): + def evaluate(self, features: FeatureSet, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.and"] += 1 if short_circuit: results = [] for child in self.children: - result = child.evaluate(ctx, short_circuit=short_circuit) + result = child.evaluate(features, short_circuit=short_circuit) results.append(result) if not result: # short circuit @@ -117,7 +117,7 @@ class And(Statement): return Result(True, self, results) else: - results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children] + results = [child.evaluate(features, short_circuit=short_circuit) for child in self.children] success = all(results) return Result(success, self, results) @@ -135,14 +135,14 @@ class Or(Statement): super().__init__(description=description) self.children = children - def evaluate(self, ctx, short_circuit=True): + def evaluate(self, features: FeatureSet, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.or"] += 1 if short_circuit: results = [] for child in self.children: - result = child.evaluate(ctx, short_circuit=short_circuit) + result = child.evaluate(features, short_circuit=short_circuit) results.append(result) if result: # short circuit as soon as we hit one match @@ -150,7 +150,7 @@ class Or(Statement): return Result(False, self, results) else: - results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children] + results = [child.evaluate(features, short_circuit=short_circuit) for child in self.children] success = any(results) return Result(success, self, results) @@ -162,11 +162,11 @@ class Not(Statement): super().__init__(description=description) self.child = child - def evaluate(self, ctx, short_circuit=True): + def evaluate(self, features: FeatureSet, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.not"] += 1 - results = [self.child.evaluate(ctx, short_circuit=short_circuit)] + results = [self.child.evaluate(features, short_circuit=short_circuit)] success = not results[0] return Result(success, self, results) @@ -185,7 +185,7 @@ class Some(Statement): self.count = count self.children = children - def evaluate(self, ctx, short_circuit=True): + def evaluate(self, features: FeatureSet, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.some"] += 1 @@ -193,7 +193,7 @@ class Some(Statement): results = [] satisfied_children_count = 0 for child in self.children: - result = child.evaluate(ctx, short_circuit=short_circuit) + result = child.evaluate(features, short_circuit=short_circuit) results.append(result) if result: satisfied_children_count += 1 @@ -204,7 +204,7 @@ class Some(Statement): return Result(False, self, results) else: - results = [child.evaluate(ctx, short_circuit=short_circuit) for child in self.children] + results = [child.evaluate(features, short_circuit=short_circuit) for child in self.children] # note that here we cast the child result as a bool # because we've overridden `__bool__` above. # @@ -214,7 +214,7 @@ class Some(Statement): class Range(Statement): - """match if the child is contained in the ctx set with a count in the given range.""" + """match if the child is contained in the feature set with a count in the given range.""" def __init__(self, child, min=None, max=None, description=None): super().__init__(description=description) @@ -222,15 +222,15 @@ class Range(Statement): self.min = min if min is not None else 0 self.max = max if max is not None else (1 << 64 - 1) - def evaluate(self, ctx, **kwargs): + def evaluate(self, features: FeatureSet, short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.range"] += 1 - count = len(ctx.get(self.child, [])) + count = len(features.get(self.child, [])) if self.min == 0 and count == 0: return Result(True, self, []) - return Result(self.min <= count <= self.max, self, [], locations=ctx.get(self.child)) + return Result(self.min <= count <= self.max, self, [], locations=features.get(self.child)) def __str__(self): if self.max == (1 << 64 - 1): @@ -250,7 +250,7 @@ class Subscope(Statement): self.scope = scope self.child = child - def evaluate(self, ctx, **kwargs): + def evaluate(self, features: FeatureSet, short_circuit=True): raise ValueError("cannot evaluate a subscope directly!") diff --git a/capa/features/common.py b/capa/features/common.py index 5cbe684d..4b02b5ce 100644 --- a/capa/features/common.py +++ b/capa/features/common.py @@ -166,10 +166,10 @@ class Feature(abc.ABC): # noqa: B024 def __repr__(self): return str(self) - def evaluate(self, ctx: Dict["Feature", Set[Address]], **kwargs) -> Result: + def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True) -> Result: capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature." + self.name] += 1 - return Result(self in ctx, self, [], locations=ctx.get(self, set())) + return Result(self in features, self, [], locations=features.get(self, set())) class MatchedRule(Feature): @@ -207,7 +207,7 @@ class Substring(String): super().__init__(value, description=description) self.value = value - def evaluate(self, ctx, short_circuit=True): + def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.substring"] += 1 @@ -216,7 +216,7 @@ class Substring(String): matches: typing.DefaultDict[str, Set[Address]] = collections.defaultdict(set) assert isinstance(self.value, str) - for feature, locations in ctx.items(): + for feature, locations in features.items(): if not isinstance(feature, (String,)): continue @@ -299,7 +299,7 @@ class Regex(String): f"invalid regular expression: {value} it should use Python syntax, try it at https://pythex.org" ) from exc - def evaluate(self, ctx, short_circuit=True): + def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.regex"] += 1 @@ -307,7 +307,7 @@ class Regex(String): # will unique the locations later on. matches: typing.DefaultDict[str, Set[Address]] = collections.defaultdict(set) - for feature, locations in ctx.items(): + for feature, locations in features.items(): if not isinstance(feature, (String,)): continue @@ -384,12 +384,12 @@ class Bytes(Feature): super().__init__(value, description=description) self.value = value - def evaluate(self, ctx, **kwargs): + def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.bytes"] += 1 assert isinstance(self.value, bytes) - for feature, locations in ctx.items(): + for feature, locations in features.items(): if not isinstance(feature, (Bytes,)): continue @@ -434,11 +434,11 @@ class OS(Feature): super().__init__(value, description=description) self.name = "os" - def evaluate(self, ctx, **kwargs): + def evaluate(self, features: "capa.engine.FeatureSet", short_circuit=True): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature." + self.name] += 1 - for feature, locations in ctx.items(): + for feature, locations in features.items(): if not isinstance(feature, (OS,)): continue diff --git a/tests/test_engine.py b/tests/test_engine.py index 785896a3..c942ed67 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -93,10 +93,10 @@ def test_complex(): def test_range(): # unbounded range, but no matching feature # since the lower bound is zero, and there are zero matches, ok - assert bool(Range(Number(1)).evaluate({Number(2): {}})) is True + assert bool(Range(Number(1)).evaluate({Number(2): {}})) is True # type: ignore # unbounded range with matching feature should always match - assert bool(Range(Number(1)).evaluate({Number(1): {}})) is True + assert bool(Range(Number(1)).evaluate({Number(1): {}})) is True # type: ignore assert bool(Range(Number(1)).evaluate({Number(1): {ADDR1}})) is True # unbounded max @@ -112,12 +112,12 @@ def test_range(): assert bool(Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}})) is False # we can do an exact match by setting min==max - assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {}})) is False + assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {}})) is False # type: ignore assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1}})) is True assert bool(Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1, ADDR2}})) is False # bounded range - assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {}})) is False + assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {}})) is False # type: ignore assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1}})) is True assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2}})) is True assert bool(Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}})) is True