diff --git a/capa/engine.py b/capa/engine.py index bd26f454..b5fbb412 100644 --- a/capa/engine.py +++ b/capa/engine.py @@ -8,7 +8,7 @@ import copy import collections -from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable +from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable, Iterator, Union, cast import capa.perf import capa.features.common @@ -60,17 +60,24 @@ class Statement: """ raise NotImplementedError() - def get_children(self): + def get_children(self) -> Iterator[Union["Statement", Feature]]: if hasattr(self, "child"): - yield self.child + # this really confuses mypy because the property may not exist + # since its defined in the subclasses. + child = self.child # type: ignore + assert isinstance(child, (Statement, Feature)) + yield child if hasattr(self, "children"): for child in getattr(self, "children"): + assert isinstance(child, (Statement, Feature)) yield child def replace_child(self, existing, new): if hasattr(self, "child"): - if self.child is existing: + # this really confuses mypy because the property may not exist + # since its defined in the subclasses. + if self.child is existing: # type: ignore self.child = new if hasattr(self, "children"): diff --git a/capa/features/common.py b/capa/features/common.py index a8dca781..dca0d03f 100644 --- a/capa/features/common.py +++ b/capa/features/common.py @@ -200,8 +200,9 @@ class Substring(String): # mapping from string value to list of locations. # will unique the locations later on. 
- matches = collections.defaultdict(list) + matches: collections.defaultdict[str, Set[Address]] = collections.defaultdict(set) + assert isinstance(self.value, str) for feature, locations in ctx.items(): if not isinstance(feature, (String,)): continue @@ -211,32 +212,29 @@ class Substring(String): raise ValueError("unexpected feature value type") if self.value in feature.value: - matches[feature.value].extend(locations) + matches[feature.value].update(locations) if short_circuit: # we found one matching string, thats sufficient to match. # don't collect other matching strings in this mode. break if matches: - # finalize: defaultdict -> dict - # which makes json serialization easier - matches = dict(matches) - # collect all locations locations = set() - for s in matches.keys(): - matches[s] = list(set(matches[s])) - locations.update(matches[s]) + for locs in matches.values(): + locations.update(locs) # unlike other features, we cannot return put a reference to `self` directly in a `Result`. # this is because `self` may match on many strings, so we can't stuff the matched value into it. # instead, return a new instance that has a reference to both the substring and the matched values. - return Result(True, _MatchedSubstring(self, matches), [], locations=locations) + return Result(True, _MatchedSubstring(self, dict(matches)), [], locations=locations) else: return Result(False, _MatchedSubstring(self, {}), []) def __str__(self): - return "substring(%s)" % self.value + v = self.value + assert isinstance(v, str) + return "substring(%s)" % v class _MatchedSubstring(Substring): @@ -261,6 +259,7 @@ class _MatchedSubstring(Substring): self.matches = matches def __str__(self): + assert isinstance(self.value, str) return 'substring("%s", matches = %s)' % ( self.value, ", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())), @@ -292,7 +291,7 @@ class Regex(String): # mapping from string value to list of locations. # will unique the locations later on. 
- matches = collections.defaultdict(list) + matches: collections.defaultdict[str, Set[Address]] = collections.defaultdict(set) for feature, locations in ctx.items(): if not isinstance(feature, (String,)): @@ -307,32 +306,28 @@ class Regex(String): # using this mode cleans is more convenient for rule authors, # so that they don't have to prefix/suffix their terms like: /.*foo.*/. if self.re.search(feature.value): - matches[feature.value].extend(locations) + matches[feature.value].update(locations) if short_circuit: # we found one matching string, thats sufficient to match. # don't collect other matching strings in this mode. break if matches: - # finalize: defaultdict -> dict - # which makes json serialization easier - matches = dict(matches) - # collect all locations locations = set() - for s in matches.keys(): - matches[s] = list(set(matches[s])) - locations.update(matches[s]) + for locs in matches.values(): + locations.update(locs) # unlike other features, we cannot return put a reference to `self` directly in a `Result`. # this is because `self` may match on many strings, so we can't stuff the matched value into it. # instead, return a new instance that has a reference to both the regex and the matched values. # see #262. 
- return Result(True, _MatchedRegex(self, matches), [], locations=locations) + return Result(True, _MatchedRegex(self, dict(matches)), [], locations=locations) else: return Result(False, _MatchedRegex(self, {}), []) def __str__(self): + assert isinstance(self.value, str) return "regex(string =~ %s)" % self.value @@ -358,6 +353,7 @@ class _MatchedRegex(Regex): self.matches = matches def __str__(self): + assert isinstance(self.value, str) return "regex(string =~ %s, matches = %s)" % ( self.value, ", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())), @@ -380,16 +376,19 @@ class Bytes(Feature): capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature.bytes"] += 1 + assert isinstance(self.value, bytes) for feature, locations in ctx.items(): if not isinstance(feature, (Bytes,)): continue + assert isinstance(feature.value, bytes) if feature.value.startswith(self.value): return Result(True, self, [], locations=locations) return Result(False, self, []) def get_value_str(self): + assert isinstance(self.value, bytes) return hex_string(bytes_to_str(self.value)) diff --git a/capa/features/extractors/dnfile/helpers.py b/capa/features/extractors/dnfile/helpers.py index 3fef794d..2c489c22 100644 --- a/capa/features/extractors/dnfile/helpers.py +++ b/capa/features/extractors/dnfile/helpers.py @@ -107,8 +107,18 @@ class DnUnmanagedMethod: return f"{module}.{method}" +def validate_has_dotnet(pe: dnfile.dnPE): + assert pe.net is not None + assert pe.net.mdtables is not None + assert pe.net.Flags is not None + + def resolve_dotnet_token(pe: dnfile.dnPE, token: Token) -> Any: """map generic token to string or table row""" + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.mdtables is not None + if isinstance(token, StringToken): user_string: Optional[str] = read_dotnet_user_string(pe, token) if user_string is None: @@ -143,6 +153,10 @@ def read_dotnet_method_body(pe: dnfile.dnPE, row: dnfile.mdtable.MethodDefRow) - def 
read_dotnet_user_string(pe: dnfile.dnPE, token: StringToken) -> Optional[str]: """read user string from #US stream""" + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.user_strings is not None + try: user_string: Optional[dnfile.stream.UserString] = pe.net.user_strings.get_us(token.rid) except UnicodeDecodeError as e: @@ -169,6 +183,11 @@ def get_dotnet_managed_imports(pe: dnfile.dnPE) -> Iterator[DnType]: TypeName (index into String heap) TypeNamespace (index into String heap) """ + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.mdtables is not None + assert pe.net.mdtables.MemberRef is not None + for (rid, row) in enumerate(iter_dotnet_table(pe, "MemberRef")): if not isinstance(row.Class.row, dnfile.mdtable.TypeRefRow): continue @@ -258,6 +277,11 @@ def get_dotnet_properties(pe: dnfile.dnPE) -> Iterator[DnType]: def get_dotnet_managed_method_bodies(pe: dnfile.dnPE) -> Iterator[Tuple[int, CilMethodBody]]: """get managed methods from MethodDef table""" + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.mdtables is not None + assert pe.net.mdtables.MethodDef is not None + if not hasattr(pe.net.mdtables, "MethodDef"): return @@ -307,15 +331,28 @@ def calculate_dotnet_token_value(table: int, rid: int) -> int: def is_dotnet_table_valid(pe: dnfile.dnPE, table_name: str) -> bool: + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.mdtables is not None + return bool(getattr(pe.net.mdtables, table_name, None)) def is_dotnet_mixed_mode(pe: dnfile.dnPE) -> bool: + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.Flags is not None + return not bool(pe.net.Flags.CLR_ILONLY) def iter_dotnet_table(pe: dnfile.dnPE, name: str) -> Iterator[Any]: + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.mdtables is not None + if not is_dotnet_table_valid(pe, name): return + for row in getattr(pe.net.mdtables, name): yield row diff --git a/capa/features/extractors/dnfile_.py 
b/capa/features/extractors/dnfile_.py index 998ea209..cf82bbce 100644 --- a/capa/features/extractors/dnfile_.py +++ b/capa/features/extractors/dnfile_.py @@ -19,9 +19,19 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[Feature, Address]]: yield OS(OS_ANY), NO_ADDRESS -def extract_file_arch(pe, **kwargs) -> Iterator[Tuple[Feature, Address]]: +def validate_has_dotnet(pe: dnfile.dnPE): + assert pe.net is not None + assert pe.net.mdtables is not None + assert pe.net.Flags is not None + + +def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Feature, Address]]: # to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020 # .NET 4.5 added option: any CPU, 32-bit preferred + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.Flags is not None + if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE: yield Arch(ARCH_I386), NO_ADDRESS elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS: @@ -71,6 +81,10 @@ class DnfileFeatureExtractor(FeatureExtractor): # self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT # True: native EP: Token # False: managed EP: RVA + validate_has_dotnet(self.pe) + assert self.pe.net is not None + assert self.pe.net.struct is not None + return self.pe.net.struct.EntryPointTokenOrRva def extract_global_features(self): @@ -83,13 +97,32 @@ class DnfileFeatureExtractor(FeatureExtractor): return bool(self.pe.net) def is_mixed_mode(self) -> bool: + validate_has_dotnet(self.pe) + assert self.pe is not None + assert self.pe.net is not None + assert self.pe.net.Flags is not None + return not bool(self.pe.net.Flags.CLR_ILONLY) def get_runtime_version(self) -> Tuple[int, int]: + validate_has_dotnet(self.pe) + assert self.pe is not None + assert self.pe.net is not None + assert self.pe.net.struct is not None + return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion def get_meta_version_string(self) -> str: - return 
self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8") + validate_has_dotnet(self.pe) + assert self.pe.net is not None + assert self.pe.net.metadata is not None + assert self.pe.net.metadata.struct is not None + assert self.pe.net.metadata.struct.Version is not None + + vbuf = self.pe.net.metadata.struct.Version + assert isinstance(vbuf, bytes) + + return vbuf.rstrip(b"\x00").decode("utf-8") def get_functions(self): raise NotImplementedError("DnfileFeatureExtractor can only be used to extract file features") diff --git a/capa/features/extractors/dotnetfile.py b/capa/features/extractors/dotnetfile.py index ef6f9f07..e7bb67fc 100644 --- a/capa/features/extractors/dotnetfile.py +++ b/capa/features/extractors/dotnetfile.py @@ -40,6 +40,12 @@ def extract_file_format(**kwargs) -> Iterator[Tuple[Format, Address]]: yield Format(FORMAT_DOTNET), NO_ADDRESS +def validate_has_dotnet(pe: dnfile.dnPE): + assert pe.net is not None + assert pe.net.mdtables is not None + assert pe.net.Flags is not None + + def extract_file_import_names(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Import, Address]]: for method in get_dotnet_managed_imports(pe): # like System.IO.File::OpenRead @@ -78,6 +84,12 @@ def extract_file_namespace_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple def extract_file_class_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Class, Address]]: """emit class features from TypeRef and TypeDef tables""" + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.mdtables is not None + assert pe.net.mdtables.TypeDef is not None + assert pe.net.mdtables.TypeRef is not None + for (rid, row) in enumerate(iter_dotnet_table(pe, "TypeDef")): token = calculate_dotnet_token_value(pe.net.mdtables.TypeDef.number, rid + 1) yield Class(DnType.format_name(row.TypeName, namespace=row.TypeNamespace)), DNTokenAddress(token) @@ -94,6 +106,10 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[OS, Address]]: def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> 
Iterator[Tuple[Arch, Address]]: # to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020 # .NET 4.5 added option: any CPU, 32-bit preferred + validate_has_dotnet(pe) + assert pe.net is not None + assert pe.net.Flags is not None + if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE: yield Arch(ARCH_I386), NO_ADDRESS elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS: @@ -155,6 +171,10 @@ class DotnetFileFeatureExtractor(FeatureExtractor): # self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT # True: native EP: Token # False: managed EP: RVA + validate_has_dotnet(self.pe) + assert self.pe.net is not None + assert self.pe.net.struct is not None + return self.pe.net.struct.EntryPointTokenOrRva def extract_global_features(self): @@ -170,10 +190,25 @@ class DotnetFileFeatureExtractor(FeatureExtractor): return is_dotnet_mixed_mode(self.pe) def get_runtime_version(self) -> Tuple[int, int]: + validate_has_dotnet(self.pe) + assert self.pe.net is not None + assert self.pe.net.struct is not None + assert self.pe.net.struct.MajorRuntimeVersion is not None + assert self.pe.net.struct.MinorRuntimeVersion is not None + return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion def get_meta_version_string(self) -> str: - return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8") + validate_has_dotnet(self.pe) + assert self.pe.net is not None + assert self.pe.net.metadata is not None + assert self.pe.net.metadata.struct is not None + assert self.pe.net.metadata.struct.Version is not None + + vbuf = self.pe.net.metadata.struct.Version + assert isinstance(vbuf, bytes) + + return vbuf.rstrip(b"\x00").decode("utf-8") def get_functions(self): raise NotImplementedError("DotnetFileFeatureExtractor can only be used to extract file features") diff --git a/capa/features/extractors/null.py b/capa/features/extractors/null.py index d5cf72ab..f8d6d077 100644 --- 
a/capa/features/extractors/null.py +++ b/capa/features/extractors/null.py @@ -52,26 +52,26 @@ class NullFeatureExtractor(FeatureExtractor): yield FunctionHandle(address, None) def extract_function_features(self, f): - for address, feature in self.functions.get(f.address, {}).features: + for address, feature in self.functions[f.address].features: yield feature, address def get_basic_blocks(self, f): - for address in sorted(self.functions.get(f.address, {}).basic_blocks.keys()): + for address in sorted(self.functions[f.address].basic_blocks.keys()): yield BBHandle(address, None) def extract_basic_block_features(self, f, bb): - for address, feature in self.functions.get(f.address, {}).basic_blocks.get(bb.address, {}).features: + for address, feature in self.functions[f.address].basic_blocks[bb.address].features: yield feature, address def get_instructions(self, f, bb): - for address in sorted(self.functions.get(f.address, {}).basic_blocks.get(bb.address, {}).instructions.keys()): + for address in sorted(self.functions[f.address].basic_blocks[bb.address].instructions.keys()): yield InsnHandle(address, None) def extract_insn_features(self, f, bb, insn): for address, feature in ( - self.functions.get(f.address, {}) - .basic_blocks.get(bb.address, {}) - .instructions.get(insn.address, {}) + self.functions[f.address] + .basic_blocks[bb.address] + .instructions[insn.address] .features ): yield feature, address diff --git a/capa/features/extractors/pefile.py b/capa/features/extractors/pefile.py index dbdf72ac..038200b8 100644 --- a/capa/features/extractors/pefile.py +++ b/capa/features/extractors/pefile.py @@ -133,7 +133,8 @@ def extract_file_features(pe, buf): """ for file_handler in FILE_HANDLERS: - for feature, va in file_handler(pe=pe, buf=buf): + # file_handler: type: (pe, bytes) -> Iterable[Tuple[Feature, Address]] + for feature, va in file_handler(pe=pe, buf=buf): # type: ignore yield feature, va @@ -160,7 +161,8 @@ def extract_global_features(pe, buf): Tuple[Feature, 
VA]: a feature and its location. """ for handler in GLOBAL_HANDLERS: - for feature, va in handler(pe=pe, buf=buf): + # file_handler: type: (pe, bytes) -> Iterable[Tuple[Feature, Address]] + for feature, va in handler(pe=pe, buf=buf): # type: ignore yield feature, va diff --git a/capa/features/extractors/smda/file.py b/capa/features/extractors/smda/file.py index f4bae925..fa2692ce 100644 --- a/capa/features/extractors/smda/file.py +++ b/capa/features/extractors/smda/file.py @@ -88,7 +88,8 @@ def extract_features(smda_report, buf): """ for file_handler in FILE_HANDLERS: - for feature, addr in file_handler(smda_report=smda_report, buf=buf): + # file_handler: type: (smda_report, bytes) -> Iterable[Tuple[Feature, Address]] + for feature, addr in file_handler(smda_report=smda_report, buf=buf): # type: ignore yield feature, addr diff --git a/capa/ida/plugin/form.py b/capa/ida/plugin/form.py index 9d101429..e6deecb7 100644 --- a/capa/ida/plugin/form.py +++ b/capa/ida/plugin/form.py @@ -11,7 +11,7 @@ import copy import logging import itertools import collections -from typing import Set, Dict, Optional +from typing import Set, Dict, Optional, List, Any import idaapi import ida_kernwin @@ -72,14 +72,14 @@ def trim_function_name(f, max_length=25): def find_func_features(fh: FunctionHandle, extractor): """ """ - func_features: Dict[Feature, Set] = collections.defaultdict(set) - bb_features: Dict[Address, Dict] = collections.defaultdict(dict) + func_features: Dict[Feature, Set[Address]] = collections.defaultdict(set) + bb_features: Dict[Address, Dict[Feature, Set[Address]]] = collections.defaultdict(dict) for (feature, addr) in extractor.extract_function_features(fh): func_features[feature].add(addr) for bbh in extractor.get_basic_blocks(fh): - _bb_features = collections.defaultdict(set) + _bb_features: Dict[Feature, Set[Address]] = collections.defaultdict(set) for (feature, addr) in extractor.extract_basic_block_features(fh, bbh): _bb_features[feature].add(addr) @@ -239,53 
+239,52 @@ class CapaSettingsInputDialog(QtWidgets.QDialog): class CapaExplorerForm(idaapi.PluginForm): """form element for plugin interface""" - def __init__(self, name, option=Options.DEFAULT): + def __init__(self, name: str, option=Options.DEFAULT): """initialize form elements""" super().__init__() - self.form_title = name - self.process_total = 0 - self.process_count = 0 + self.form_title: str = name + self.process_total: int = 0 + self.process_count: int = 0 - self.parent = None - self.ida_hooks = None + self.parent: Any # QtWidget + self.ida_hooks: CapaExplorerIdaHooks self.doc: Optional[capa.render.result_document.ResultDocument] = None - self.rule_paths = None - self.rules_cache = None - self.ruleset_cache = None + self.rule_paths: Optional[List[str]] + self.rules_cache: Optional[List[capa.rules.Rule]] + self.ruleset_cache: Optional[capa.rules.RuleSet] # models - self.model_data = None - self.range_model_proxy = None - self.search_model_proxy = None + self.model_data: CapaExplorerDataModel + self.range_model_proxy: CapaExplorerRangeProxyModel + self.search_model_proxy: CapaExplorerSearchProxyModel # UI controls - self.view_limit_results_by_function = None - self.view_show_results_by_function = None - self.view_search_bar = None - self.view_tree = None - self.view_rulegen = None - self.view_tabs = None + self.view_limit_results_by_function: QtWidgets.QCheckBox + self.view_show_results_by_function: QtWidgets.QCheckBox + self.view_search_bar: QtWidgets.QLineEdit + self.view_tree: CapaExplorerQtreeView + self.view_tabs: QtWidgets.QTabWidget self.view_tab_rulegen = None - self.view_status_label = None - self.view_buttons = None - self.view_analyze_button = None - self.view_reset_button = None - self.view_settings_button = None - self.view_save_button = None + self.view_status_label: QtWidgets.QLabel + self.view_buttons: QtWidgets.QHBoxLayout + self.view_analyze_button: QtWidgets.QPushButton + self.view_reset_button: QtWidgets.QPushButton + 
self.view_settings_button: QtWidgets.QPushButton + self.view_save_button: QtWidgets.QPushButton - self.view_rulegen_preview = None - self.view_rulegen_features = None - self.view_rulegen_editor = None - self.view_rulegen_header_label = None - self.view_rulegen_search = None - self.view_rulegen_limit_features_by_ea = None - self.rulegen_current_function = None - self.rulegen_bb_features_cache = {} - self.rulegen_func_features_cache = {} - self.rulegen_file_features_cache = {} - self.view_rulegen_status_label = None + self.view_rulegen_preview: CapaExplorerRulegenPreview + self.view_rulegen_features: CapaExplorerRulegenFeatures + self.view_rulegen_editor: CapaExplorerRulegenEditor + self.view_rulegen_header_label: QtWidgets.QLabel + self.view_rulegen_search: QtWidgets.QLineEdit + self.view_rulegen_limit_features_by_ea: QtWidgets.QCheckBox + self.rulegen_current_function: Optional[FunctionHandle] + self.rulegen_bb_features_cache: Dict[Address, Dict[Feature, Set[Address]]] = {} + self.rulegen_func_features_cache: Dict[Feature, Set[Address]] = {} + self.rulegen_file_features_cache: Dict[Feature, Set[Address]] = {} + self.view_rulegen_status_label: QtWidgets.QLabel self.Show() @@ -762,6 +761,9 @@ class CapaExplorerForm(idaapi.PluginForm): if not self.load_capa_rules(): return False + assert self.rules_cache is not None + assert self.ruleset_cache is not None + if ida_kernwin.user_cancelled(): logger.info("User cancelled analysis.") return False @@ -822,6 +824,13 @@ class CapaExplorerForm(idaapi.PluginForm): return False try: + # either the results are cached and the doc already exists, + # or the doc was just created above + assert self.doc is not None + # same with rules cache, either it's cached or it was just loaded + assert self.rules_cache is not None + assert self.ruleset_cache is not None + self.model_data.render_capa_doc(self.doc, self.view_show_results_by_function.isChecked()) self.set_view_status_label( "capa rules directory: %s (%d rules)" % 
(settings.user[CAPA_SETTINGS_RULE_PATH], len(self.rules_cache)) @@ -871,6 +880,9 @@ class CapaExplorerForm(idaapi.PluginForm): else: logger.info('Using cached ruleset, click "Reset" to reload rules from disk.') + assert self.rules_cache is not None + assert self.ruleset_cache is not None + if ida_kernwin.user_cancelled(): logger.info("User cancelled analysis.") return False @@ -891,7 +903,8 @@ class CapaExplorerForm(idaapi.PluginForm): try: f = idaapi.get_func(idaapi.get_screen_ea()) if f: - fh: FunctionHandle = extractor.get_function(f.start_ea) + fh: Optional[FunctionHandle] = extractor.get_function(f.start_ea) + assert fh is not None self.rulegen_current_function = fh func_features, bb_features = find_func_features(fh, extractor) @@ -1053,6 +1066,8 @@ class CapaExplorerForm(idaapi.PluginForm): def update_rule_status(self, rule_text): """ """ + assert self.rules_cache is not None + if not self.view_rulegen_editor.invisibleRootItem().childCount(): self.set_rulegen_preview_border_neutral() self.view_rulegen_status_label.clear() @@ -1077,7 +1092,7 @@ class CapaExplorerForm(idaapi.PluginForm): rules.append(rule) try: - file_features = copy.copy(self.rulegen_file_features_cache) + file_features = copy.copy(dict(self.rulegen_file_features_cache)) if self.rulegen_current_function: func_matches, bb_matches = find_func_matches( self.rulegen_current_function, @@ -1093,7 +1108,7 @@ class CapaExplorerForm(idaapi.PluginForm): _, file_matches = capa.engine.match( capa.rules.RuleSet(list(capa.rules.get_rules_and_dependencies(rules, rule.name))).file_rules, file_features, - 0x0, + NO_ADDRESS ) except Exception as e: self.set_rulegen_status("Failed to match rule (%s)" % e) diff --git a/capa/ida/plugin/item.py b/capa/ida/plugin/item.py index 159333a4..ac349424 100644 --- a/capa/ida/plugin/item.py +++ b/capa/ida/plugin/item.py @@ -36,7 +36,7 @@ def ea_to_hex(ea): class CapaExplorerDataItem: """store data for CapaExplorerDataModel""" - def __init__(self, parent: 
"CapaExplorerDataItem", data: List[str], can_check=True): + def __init__(self, parent: Optional["CapaExplorerDataItem"], data: List[str], can_check=True): """initialize item""" self.pred = parent self._data = data @@ -110,7 +110,7 @@ class CapaExplorerDataItem: except IndexError: return None - def parent(self) -> "CapaExplorerDataItem": + def parent(self) -> Optional["CapaExplorerDataItem"]: """get parent""" return self.pred diff --git a/capa/ida/plugin/proxy.py b/capa/ida/plugin/proxy.py index ae490d87..e67147bd 100644 --- a/capa/ida/plugin/proxy.py +++ b/capa/ida/plugin/proxy.py @@ -92,7 +92,7 @@ class CapaExplorerRangeProxyModel(QtCore.QSortFilterProxyModel): @param parent: QModelIndex of parent """ # filter not set - if self.min_ea is None and self.max_ea is None: + if self.min_ea is None or self.max_ea is None: return True index = self.sourceModel().index(row, 0, parent) diff --git a/capa/ida/plugin/view.py b/capa/ida/plugin/view.py index 86505fb1..75abf59c 100644 --- a/capa/ida/plugin/view.py +++ b/capa/ida/plugin/view.py @@ -18,7 +18,7 @@ import capa.ida.helpers import capa.features.common import capa.features.basicblock from capa.ida.plugin.item import CapaExplorerFunctionItem -from capa.features.address import Address, _NoAddress +from capa.features.address import _NoAddress, AbsoluteVirtualAddress from capa.ida.plugin.model import CapaExplorerDataModel MAX_SECTION_SIZE = 750 @@ -1013,8 +1013,10 @@ class CapaExplorerRulegenFeatures(QtWidgets.QTreeWidget): self.parent_items = {} def format_address(e): - assert isinstance(e, Address) - return "%X" % e if not isinstance(e, _NoAddress) else "" + if isinstance(e, AbsoluteVirtualAddress): + return "%X" % int(e) + else: + return "" def format_feature(feature): """ """ diff --git a/capa/main.py b/capa/main.py index 1157a474..c973b61d 100644 --- a/capa/main.py +++ b/capa/main.py @@ -66,7 +66,7 @@ from capa.features.common import ( FORMAT_DOTNET, FORMAT_FREEZE, ) -from capa.features.address import NO_ADDRESS +from 
capa.features.address import NO_ADDRESS, Address from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle, FeatureExtractor RULES_PATH_DEFAULT_STRING = "(embedded rules)" @@ -718,8 +718,8 @@ def compute_layout(rules, extractor, capabilities): otherwise, we may pollute the json document with a large amount of un-referenced data. """ - functions_by_bb = {} - bbs_by_function = {} + functions_by_bb: Dict[Address, Address] = {} + bbs_by_function: Dict[Address, List[Address]] = {} for f in extractor.get_functions(): bbs_by_function[f.address] = [] for bb in extractor.get_basic_blocks(f): @@ -1016,8 +1016,7 @@ def main(argv=None): return E_INVALID_FILE_TYPE try: - rules = get_rules(args.rules, disable_progress=args.quiet) - rules = capa.rules.RuleSet(rules) + rules = capa.rules.RuleSet(get_rules(args.rules, disable_progress=args.quiet)) logger.debug( "successfully loaded %s rules", @@ -1167,8 +1166,7 @@ def ida_main(): rules_path = os.path.join(get_default_root(), "rules") logger.debug("rule path: %s", rules_path) - rules = get_rules([rules_path]) - rules = capa.rules.RuleSet(rules) + rules = capa.rules.RuleSet(get_rules([rules_path])) meta = capa.ida.helpers.collect_metadata([rules_path]) diff --git a/capa/perf.py b/capa/perf.py index cb0e89ec..1d98f6c2 100644 --- a/capa/perf.py +++ b/capa/perf.py @@ -2,7 +2,7 @@ import collections from typing import Dict # this structure is unstable and may change before the next major release. -counters: Dict[str, int] = collections.Counter() +counters: collections.Counter[str] = collections.Counter() def reset(): diff --git a/capa/rules.py b/capa/rules.py index 8287eba0..c4d2ad77 100644 --- a/capa/rules.py +++ b/capa/rules.py @@ -634,7 +634,7 @@ class Rule: Returns: List[str]: names of rules upon which this rule depends. 
""" - deps = set([]) + deps: Set[str] = set([]) def rec(statement): if isinstance(statement, capa.features.common.MatchedRule): @@ -651,6 +651,7 @@ class Rule: deps.update(map(lambda r: r.name, namespaces[statement.value])) else: # not a namespace, assume its a rule name. + assert isinstance(statement.value, str) deps.add(statement.value) elif isinstance(statement, ceng.Statement): @@ -666,7 +667,11 @@ class Rule: def _extract_subscope_rules_rec(self, statement): if isinstance(statement, ceng.Statement): # for each child that is a subscope, - for subscope in filter(lambda statement: isinstance(statement, ceng.Subscope), statement.get_children()): + for child in statement.get_children(): + if not isinstance(child, ceng.Subscope): + continue + + subscope = child # create a new rule from it. # the name is a randomly generated, hopefully unique value. @@ -737,7 +742,7 @@ class Rule: return self.statement.evaluate(features, short_circuit=short_circuit) @classmethod - def from_dict(cls, d, definition): + def from_dict(cls, d, definition) -> "Rule": meta = d["rule"]["meta"] name = meta["name"] # if scope is not specified, default to function scope. @@ -771,14 +776,12 @@ class Rule: # prefer to use CLoader to be fast, see #306 # on Linux, make sure you install libyaml-dev or similar # on Windows, get WHLs from pyyaml.org/pypi - loader = yaml.CLoader logger.debug("using libyaml CLoader.") + return yaml.CLoader except: - loader = yaml.Loader logger.debug("unable to import libyaml CLoader, falling back to Python yaml parser.") logger.debug("this will be slower to load rules.") - - return loader + return yaml.Loader @staticmethod def _get_ruamel_yaml_parser(): @@ -790,8 +793,9 @@ class Rule: # use block mode, not inline json-like mode y.default_flow_style = False - # leave quotes unchanged - y.preserve_quotes = True + # leave quotes unchanged. + # manually verified this property exists, even if mypy complains. 
+ y.preserve_quotes = True # type: ignore # indent lists by two spaces below their parent # @@ -802,12 +806,13 @@ class Rule: y.indent(sequence=2, offset=2) # avoid word wrapping - y.width = 4096 + # manually verified this property exists, even if mypy complains. + y.width = 4096 # type: ignore return y @classmethod - def from_yaml(cls, s, use_ruamel=False): + def from_yaml(cls, s, use_ruamel=False) -> "Rule": if use_ruamel: # ruamel enables nice formatting and doc roundtripping with comments doc = cls._get_ruamel_yaml_parser().load(s) @@ -817,7 +822,7 @@ class Rule: return cls.from_dict(doc, s) @classmethod - def from_yaml_file(cls, path, use_ruamel=False): + def from_yaml_file(cls, path, use_ruamel=False) -> "Rule": with open(path, "rb") as f: try: rule = cls.from_yaml(f.read().decode("utf-8"), use_ruamel=use_ruamel) @@ -832,7 +837,7 @@ class Rule: except pydantic.ValidationError as e: raise InvalidRuleWithPath(path, str(e)) from e - def to_yaml(self): + def to_yaml(self) -> str: # reformat the yaml document with a common style. # this includes: # - ordering the meta elements @@ -1261,7 +1266,7 @@ class RuleSet: return (easy_rules_by_feature, hard_rules) @staticmethod - def _get_rules_for_scope(rules, scope): + def _get_rules_for_scope(rules, scope) -> List[Rule]: """ given a collection of rules, collect the rules that are needed at the given scope. these rules are ordered topologically. @@ -1269,7 +1274,7 @@ class RuleSet: don't include auto-generated "subscope" rules. we want to include general "lib" rules here - even if they are not dependencies of other rules, see #398 """ - scope_rules = set([]) + scope_rules: Set[Rule] = set([]) # we need to process all rules, not just rules with the given scope. # this is because rules with a higher scope, e.g. 
file scope, may have subscope rules @@ -1283,7 +1288,7 @@ class RuleSet: return get_rules_with_scope(topologically_order_rules(list(scope_rules)), scope) @staticmethod - def _extract_subscope_rules(rules): + def _extract_subscope_rules(rules) -> List[Rule]: """ process the given sequence of rules. for each one, extract any embedded subscope rules into their own rule. diff --git a/rules b/rules index 2bc58afb..5ba70c97 160000 --- a/rules +++ b/rules @@ -1 +1 @@ -Subproject commit 2bc58afb5184a914ae13152df4ef09eb18ee3e79 +Subproject commit 5ba70c97d22dd59efcf29a128557e64213f7ace8 diff --git a/scripts/bulk-process.py b/scripts/bulk-process.py index 8ec23903..b57928c6 100644 --- a/scripts/bulk-process.py +++ b/scripts/bulk-process.py @@ -152,8 +152,7 @@ def main(argv=None): capa.main.handle_common_args(args) try: - rules = capa.main.get_rules(args.rules) - rules = capa.rules.RuleSet(rules) + rules = capa.rules.RuleSet(capa.main.get_rules(args.rules)) logger.info("successfully loaded %s rules", len(rules)) except (IOError, capa.rules.InvalidRule, capa.rules.InvalidRuleSet) as e: logger.error("%s", str(e)) diff --git a/scripts/capa2yara.py b/scripts/capa2yara.py index 06a1d031..9474347b 100644 --- a/scripts/capa2yara.py +++ b/scripts/capa2yara.py @@ -64,7 +64,6 @@ unsupported = ["characteristic", "mnemonic", "offset", "subscope", "Range"] # collect all converted rules to be able to check if we have needed sub rules for match: converted_rules = [] -count_incomplete = 0 default_tags = "CAPA " @@ -537,7 +536,8 @@ def output_unsupported_capa_rules(yaml, capa_rulename, url, reason): unsupported_capa_rules_names.write(url.encode("utf-8") + b"\n") -def convert_rules(rules, namespaces, cround): +def convert_rules(rules, namespaces, cround, make_priv): + count_incomplete = 0 for rule in rules.rules.values(): rule_name = convert_rule_name(rule.name) @@ -652,7 +652,6 @@ def convert_rules(rules, namespaces, cround): if meta_name and meta_value: yara_meta += "\t" + meta_name + ' = "' 
+ meta_value + '"\n' - rule_name_bonus = "" if rule_comment: yara_meta += '\tcomment = "' + rule_comment + '"\n' yara_meta += '\tdate = "' + today + '"\n' @@ -679,12 +678,13 @@ def convert_rules(rules, namespaces, cround): # TODO: now the rule is finished and could be automatically checked with the capa-testfile(s) named in meta (doing it for all of them using yara-ci upload at the moment) output_yar(yara) converted_rules.append(rule_name) - global count_incomplete count_incomplete += incomplete else: output_unsupported_capa_rules(rule.to_yaml(), rule.name, url, yara_condition) pass + return count_incomplete + def main(argv=None): if argv is None: @@ -696,7 +696,6 @@ def main(argv=None): capa.main.install_common_args(parser, wanted={"tag"}) args = parser.parse_args(args=argv) - global make_priv make_priv = args.private if args.verbose: @@ -710,9 +709,9 @@ def main(argv=None): logging.getLogger("capa2yara").setLevel(level) try: - rules = capa.main.get_rules([args.rules], disable_progress=True) - namespaces = capa.rules.index_rules_by_namespace(list(rules)) - rules = capa.rules.RuleSet(rules) + rules_ = capa.main.get_rules([args.rules], disable_progress=True) + namespaces = capa.rules.index_rules_by_namespace(rules_) + rules = capa.rules.RuleSet(rules_) logger.info("successfully loaded %s rules (including subscope rules which will be ignored)", len(rules)) if args.tag: rules = rules.filter_rules_by_meta(args.tag) @@ -745,14 +744,15 @@ def main(argv=None): # do several rounds of converting rules because some rules for match: might not be converted in the 1st run num_rules = 9999999 cround = 0 + count_incomplete = 0 while num_rules != len(converted_rules) or cround < min_rounds: cround += 1 logger.info("doing convert_rules(), round: " + str(cround)) num_rules = len(converted_rules) - convert_rules(rules, namespaces, cround) + count_incomplete += convert_rules(rules, namespaces, cround, make_priv) # one last round to collect all unconverted rules - convert_rules(rules, 
namespaces, 9000) + count_incomplete += convert_rules(rules, namespaces, 9000, make_priv) stats = "\n// converted rules : " + str(len(converted_rules)) stats += "\n// among those are incomplete : " + str(count_incomplete) diff --git a/scripts/capa_as_library.py b/scripts/capa_as_library.py index 682d4dc6..2db6a644 100644 --- a/scripts/capa_as_library.py +++ b/scripts/capa_as_library.py @@ -172,7 +172,7 @@ def capa_details(rules_path, file_path, output_format="dictionary"): meta["analysis"].update(counts) meta["analysis"]["layout"] = capa.main.compute_layout(rules, extractor, capabilities) - capa_output = False + capa_output: Any = False if output_format == "dictionary": # ...as python dictionary, simplified as textable but in dictionary doc = rd.ResultDocument.from_capa(meta, rules, capabilities) diff --git a/scripts/detect-elf-os.py b/scripts/detect-elf-os.py index 63186ed8..078b80dd 100644 --- a/scripts/detect-elf-os.py +++ b/scripts/detect-elf-os.py @@ -28,7 +28,7 @@ def main(argv=None): if capa.helpers.is_runtime_ida(): from capa.ida.helpers import IDAIO - f: BinaryIO = IDAIO() + f: BinaryIO = IDAIO() # type: ignore else: if argv is None: diff --git a/scripts/lint.py b/scripts/lint.py index b3593f80..cd6e32cb 100644 --- a/scripts/lint.py +++ b/scripts/lint.py @@ -902,11 +902,15 @@ def redirecting_print_to_tqdm(): old_print(*args, **kwargs) try: - # Globaly replace print with new_print - inspect.builtins.print = new_print + # Globally replace print with new_print. 
+ # Verified this works manually on Python 3.11: + # >>> import inspect + # >>> inspect.builtins + # + inspect.builtins.print = new_print # type: ignore yield finally: - inspect.builtins.print = old_print + inspect.builtins.print = old_print # type: ignore def lint(ctx: Context): @@ -998,10 +1002,8 @@ def main(argv=None): time0 = time.time() try: - rules = capa.main.get_rules(args.rules, disable_progress=True) - rule_count = len(rules) - rules = capa.rules.RuleSet(rules) - logger.info("successfully loaded %s rules", rule_count) + rules = capa.rules.RuleSet(capa.main.get_rules(args.rules, disable_progress=True)) + logger.info("successfully loaded %s rules", len(rules)) if args.tag: rules = rules.filter_rules_by_meta(args.tag) logger.debug("selected %s rules", len(rules)) diff --git a/scripts/show-capabilities-by-function.py b/scripts/show-capabilities-by-function.py index 0c5ff361..d1773021 100644 --- a/scripts/show-capabilities-by-function.py +++ b/scripts/show-capabilities-by-function.py @@ -141,8 +141,7 @@ def main(argv=None): return -1 try: - rules = capa.main.get_rules(args.rules) - rules = capa.rules.RuleSet(rules) + rules = capa.rules.RuleSet(capa.main.get_rules(args.rules)) logger.info("successfully loaded %s rules", len(rules)) if args.tag: rules = rules.filter_rules_by_meta(args.tag) diff --git a/scripts/show-features.py b/scripts/show-features.py index 00c1eb05..d23a9a0a 100644 --- a/scripts/show-features.py +++ b/scripts/show-features.py @@ -136,7 +136,7 @@ def main(argv=None): for feature, addr in extractor.extract_file_features(): print("file: %s: %s" % (format_address(addr), feature)) - function_handles = extractor.get_functions() + function_handles = tuple(extractor.get_functions()) if args.function: if args.format == "freeze": @@ -173,7 +173,7 @@ def ida_main(): print("file: %s: %s" % (format_address(addr), feature)) return - function_handles = extractor.get_functions() + function_handles = tuple(extractor.get_functions()) if function: 
function_handles = tuple(filter(lambda fh: fh.inner.start_ea == function, function_handles)) diff --git a/tests/data b/tests/data index 0ffc189e..da6fed53 160000 --- a/tests/data +++ b/tests/data @@ -1 +1 @@ -Subproject commit 0ffc189eea6113d2dfc6355dacad8fbd78f9675d +Subproject commit da6fed53395be292ffec57a2732f0f6105c03487 diff --git a/tests/test_engine.py b/tests/test_engine.py index 26bb59ce..8fee9b92 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -8,58 +8,63 @@ from capa.engine import * from capa.features import * from capa.features.insn import * +import capa.features.address +ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001) +ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002) +ADDR3 = capa.features.address.AbsoluteVirtualAddress(0x401003) +ADDR4 = capa.features.address.AbsoluteVirtualAddress(0x401004) def test_number(): - assert Number(1).evaluate({Number(0): {1}}) == False - assert Number(1).evaluate({Number(1): {1}}) == True - assert Number(1).evaluate({Number(2): {1, 2}}) == False + assert Number(1).evaluate({Number(0): {ADDR1}}) == False + assert Number(1).evaluate({Number(1): {ADDR1}}) == True + assert Number(1).evaluate({Number(2): {ADDR1, ADDR2}}) == False def test_and(): - assert And([Number(1)]).evaluate({Number(0): {1}}) == False - assert And([Number(1)]).evaluate({Number(1): {1}}) == True - assert And([Number(1), Number(2)]).evaluate({Number(0): {1}}) == False - assert And([Number(1), Number(2)]).evaluate({Number(1): {1}}) == False - assert And([Number(1), Number(2)]).evaluate({Number(2): {1}}) == False - assert And([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {2}}) == True + assert And([Number(1)]).evaluate({Number(0): {ADDR1}}) == False + assert And([Number(1)]).evaluate({Number(1): {ADDR1}}) == True + assert And([Number(1), Number(2)]).evaluate({Number(0): {ADDR1}}) == False + assert And([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == False + assert And([Number(1), 
Number(2)]).evaluate({Number(2): {ADDR1}}) == False + assert And([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR2}}) == True def test_or(): - assert Or([Number(1)]).evaluate({Number(0): {1}}) == False - assert Or([Number(1)]).evaluate({Number(1): {1}}) == True - assert Or([Number(1), Number(2)]).evaluate({Number(0): {1}}) == False - assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True - assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True - assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {2}}) == True + assert Or([Number(1)]).evaluate({Number(0): {ADDR1}}) == False + assert Or([Number(1)]).evaluate({Number(1): {ADDR1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(0): {ADDR1}}) == False + assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR2}}) == True def test_not(): - assert Not(Number(1)).evaluate({Number(0): {1}}) == True - assert Not(Number(1)).evaluate({Number(1): {1}}) == False + assert Not(Number(1)).evaluate({Number(0): {ADDR1}}) == True + assert Not(Number(1)).evaluate({Number(1): {ADDR1}}) == False def test_some(): - assert Some(0, [Number(1)]).evaluate({Number(0): {1}}) == True - assert Some(1, [Number(1)]).evaluate({Number(0): {1}}) == False + assert Some(0, [Number(1)]).evaluate({Number(0): {ADDR1}}) == True + assert Some(1, [Number(1)]).evaluate({Number(0): {ADDR1}}) == False - assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}}) == False - assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}, Number(1): {1}}) == False - assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}}) == True + assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}}) == False + assert 
Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}}) == False + assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}}) == True assert ( Some(2, [Number(1), Number(2), Number(3)]).evaluate( - {Number(0): {1}, Number(1): {1}, Number(2): {1}, Number(3): {1}} + {Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}, Number(3): {ADDR1}} ) == True ) assert ( Some(2, [Number(1), Number(2), Number(3)]).evaluate( { - Number(0): {1}, - Number(1): {1}, - Number(2): {1}, - Number(3): {1}, - Number(4): {1}, + Number(0): {ADDR1}, + Number(1): {ADDR1}, + Number(2): {ADDR1}, + Number(3): {ADDR1}, + Number(4): {ADDR1}, } ) == True @@ -69,10 +74,10 @@ def test_some(): def test_complex(): assert True == Or( [And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5), Number(6)])])] - ).evaluate({Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}}) + ).evaluate({Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) assert False == Or([And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5)])])]).evaluate( - {Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}} + {Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}} ) @@ -83,54 +88,54 @@ def test_range(): # unbounded range with matching feature should always match assert Range(Number(1)).evaluate({Number(1): {}}) == True - assert Range(Number(1)).evaluate({Number(1): {0}}) == True + assert Range(Number(1)).evaluate({Number(1): {ADDR1}}) == True # unbounded max - assert Range(Number(1), min=1).evaluate({Number(1): {0}}) == True - assert Range(Number(1), min=2).evaluate({Number(1): {0}}) == False - assert Range(Number(1), min=2).evaluate({Number(1): {0, 1}}) == True + assert Range(Number(1), min=1).evaluate({Number(1): {ADDR1}}) == True + assert Range(Number(1), min=2).evaluate({Number(1): {ADDR1}}) == False + assert 
Range(Number(1), min=2).evaluate({Number(1): {ADDR1, ADDR2}}) == True # unbounded min - assert Range(Number(1), max=0).evaluate({Number(1): {0}}) == False - assert Range(Number(1), max=1).evaluate({Number(1): {0}}) == True - assert Range(Number(1), max=2).evaluate({Number(1): {0}}) == True - assert Range(Number(1), max=2).evaluate({Number(1): {0, 1}}) == True - assert Range(Number(1), max=2).evaluate({Number(1): {0, 1, 3}}) == False + assert Range(Number(1), max=0).evaluate({Number(1): {ADDR1}}) == False + assert Range(Number(1), max=1).evaluate({Number(1): {ADDR1}}) == True + assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1}}) == True + assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2}}) == True + assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}}) == False # we can do an exact match by setting min==max assert Range(Number(1), min=1, max=1).evaluate({Number(1): {}}) == False - assert Range(Number(1), min=1, max=1).evaluate({Number(1): {1}}) == True - assert Range(Number(1), min=1, max=1).evaluate({Number(1): {1, 2}}) == False + assert Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1}}) == True + assert Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1, ADDR2}}) == False # bounded range assert Range(Number(1), min=1, max=3).evaluate({Number(1): {}}) == False - assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1}}) == True - assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2}}) == True - assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2, 3}}) == True - assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2, 3, 4}}) == False + assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1}}) == True + assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2}}) == True + assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}}) == True + assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, 
ADDR2, ADDR3, ADDR4}}) == False def test_short_circuit(): - assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True # with short circuiting, only the children up until the first satisfied child are captured. - assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=True).children) == 1 - assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=False).children) == 2 + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}, short_circuit=True).children) == 1 + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}, short_circuit=False).children) == 2 def test_eval_order(): # base cases. - assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True - assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True + assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == True # with short circuiting, only the children up until the first satisfied child are captured. - assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children) == 1 - assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children) == 2 - assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {1}}).children) == 1 + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children) == 1 + assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children) == 2 + assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR1}}).children) == 1 # and its guaranteed that children are evaluated in order. 
- assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement == Number(1) - assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement != Number(2) + assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children[0].statement == Number(1) + assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children[0].statement != Number(2) - assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement == Number(2) - assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement != Number(1) + assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement == Number(2) + assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement != Number(1) diff --git a/tests/test_fmt.py b/tests/test_fmt.py index de96a1f4..1f37886c 100644 --- a/tests/test_fmt.py +++ b/tests/test_fmt.py @@ -98,7 +98,7 @@ def test_rule_reformat_order(): def test_rule_reformat_meta_update(): # test updating the rule content after parsing - rule = textwrap.dedent( + src = textwrap.dedent( """ rule: meta: @@ -116,7 +116,7 @@ def test_rule_reformat_meta_update(): """ ) - rule = capa.rules.Rule.from_yaml(rule) + rule = capa.rules.Rule.from_yaml(src) rule.name = "test rule" assert rule.to_yaml() == EXPECTED diff --git a/tests/test_match.py b/tests/test_match.py index 6fb319cd..2d8b9f2a 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -218,7 +218,7 @@ def test_match_matched_rules(): # the ordering of the rules must not matter, # the engine should match rules in an appropriate order. 
features, _ = match( - capa.rules.topologically_order_rules(reversed(rules)), + capa.rules.topologically_order_rules(list(reversed(rules))), {capa.features.insn.Number(100): {1}}, 0x0, ) diff --git a/tests/test_result_document.py b/tests/test_result_document.py index 8074e1cd..b98fadff 100644 --- a/tests/test_result_document.py +++ b/tests/test_result_document.py @@ -19,6 +19,7 @@ def test_optional_node_from_capa(): [], ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.CompoundStatement) assert node.statement.type == rdoc.CompoundStatementType.OPTIONAL @@ -32,6 +33,7 @@ def test_some_node_from_capa(): ], ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.SomeStatement) @@ -41,6 +43,7 @@ def test_range_node_from_capa(): capa.features.insn.Number(0), ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.RangeStatement) @@ -51,6 +54,7 @@ def test_subscope_node_from_capa(): capa.features.insn.Number(0), ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.SubscopeStatement) @@ -62,6 +66,7 @@ def test_and_node_from_capa(): ], ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.CompoundStatement) assert node.statement.type == rdoc.CompoundStatementType.AND @@ -74,6 +79,7 @@ def test_or_node_from_capa(): ], ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.CompoundStatement) assert node.statement.type == rdoc.CompoundStatementType.OR @@ -86,115 +92,138 @@ def test_not_node_from_capa(): ], ) ) + assert isinstance(node, rdoc.StatementNode) assert isinstance(node.statement, rdoc.CompoundStatement) assert node.statement.type == rdoc.CompoundStatementType.NOT def test_os_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.OS("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.OSFeature) def 
test_arch_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Arch("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.ArchFeature) def test_format_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Format("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.FormatFeature) def test_match_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.MatchedRule("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.MatchFeature) def test_characteristic_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Characteristic("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.CharacteristicFeature) def test_substring_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Substring("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.SubstringFeature) def test_regex_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Regex("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.RegexFeature) def test_class_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Class("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.ClassFeature) def test_namespace_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Namespace("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.NamespaceFeature) def test_bytes_node_from_capa(): node = rdoc.node_from_capa(capa.features.common.Bytes(b"")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.BytesFeature) def test_export_node_from_capa(): node = rdoc.node_from_capa(capa.features.file.Export("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.ExportFeature) def test_import_node_from_capa(): node = 
rdoc.node_from_capa(capa.features.file.Import("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.ImportFeature) def test_section_node_from_capa(): node = rdoc.node_from_capa(capa.features.file.Section("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.SectionFeature) def test_function_name_node_from_capa(): node = rdoc.node_from_capa(capa.features.file.FunctionName("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.FunctionNameFeature) def test_api_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.API("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.APIFeature) def test_property_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.Property("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.PropertyFeature) def test_number_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.Number(0)) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.NumberFeature) def test_offset_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.Offset(0)) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.OffsetFeature) def test_mnemonic_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.Mnemonic("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.MnemonicFeature) def test_operand_number_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.OperandNumber(0, 0)) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.OperandNumberFeature) def test_operand_offset_node_from_capa(): node = rdoc.node_from_capa(capa.features.insn.OperandOffset(0, 0)) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.OperandOffsetFeature) def test_basic_block_node_from_capa(): node = 
rdoc.node_from_capa(capa.features.basicblock.BasicBlock("")) + assert isinstance(node, rdoc.FeatureNode) assert isinstance(node.feature, frzf.BasicBlockFeature) diff --git a/tests/test_rules.py b/tests/test_rules.py index 61bef111..d5aea406 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -13,8 +13,10 @@ import pytest import capa.rules import capa.engine import capa.features.common +from capa.features.address import AbsoluteVirtualAddress from capa.features.file import FunctionName from capa.features.insn import Number, Offset, Property +from capa.engine import Or from capa.features.common import ( OS, OS_LINUX, @@ -29,12 +31,19 @@ from capa.features.common import ( Substring, FeatureAccess, ) +import capa.features.address + + +ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001) +ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002) +ADDR3 = capa.features.address.AbsoluteVirtualAddress(0x401003) +ADDR4 = capa.features.address.AbsoluteVirtualAddress(0x401004) def test_rule_ctor(): - r = capa.rules.Rule("test rule", capa.rules.FUNCTION_SCOPE, Number(1), {}) - assert r.evaluate({Number(0): {1}}) == False - assert r.evaluate({Number(1): {1}}) == True + r = capa.rules.Rule("test rule", capa.rules.FUNCTION_SCOPE, Or(Number(1)), {}) + assert r.evaluate({Number(0): {ADDR1}}) == False + assert r.evaluate({Number(1): {ADDR2}}) == True def test_rule_yaml(): @@ -56,10 +65,10 @@ def test_rule_yaml(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Number(0): {1}}) == False - assert r.evaluate({Number(0): {1}, Number(1): {1}}) == False - assert r.evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}}) == True - assert r.evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}, Number(3): {1}}) == True + assert r.evaluate({Number(0): {ADDR1}}) == False + assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}}) == False + assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}}) == True + assert 
r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}, Number(3): {ADDR1}}) == True def test_rule_yaml_complex(): @@ -82,8 +91,8 @@ def test_rule_yaml_complex(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}}) == True - assert r.evaluate({Number(6): {1}, Number(7): {1}, Number(8): {1}}) == False + assert r.evaluate({Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) == True + assert r.evaluate({Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) == False def test_rule_descriptions(): @@ -160,8 +169,8 @@ def test_rule_yaml_not(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Number(1): {1}}) == True - assert r.evaluate({Number(1): {1}, Number(2): {1}}) == False + assert r.evaluate({Number(1): {ADDR1}}) == True + assert r.evaluate({Number(1): {ADDR1}, Number(2): {ADDR1}}) == False def test_rule_yaml_count(): @@ -175,9 +184,9 @@ def test_rule_yaml_count(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Number(100): {}}) == False - assert r.evaluate({Number(100): {1}}) == True - assert r.evaluate({Number(100): {1, 2}}) == False + assert r.evaluate({Number(100): set()}) == False + assert r.evaluate({Number(100): {ADDR1}}) == True + assert r.evaluate({Number(100): {ADDR1, ADDR2}}) == False def test_rule_yaml_count_range(): @@ -191,10 +200,10 @@ def test_rule_yaml_count_range(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Number(100): {}}) == False - assert r.evaluate({Number(100): {1}}) == True - assert r.evaluate({Number(100): {1, 2}}) == True - assert r.evaluate({Number(100): {1, 2, 3}}) == False + assert r.evaluate({Number(100): set()}) == False + assert r.evaluate({Number(100): {ADDR1}}) == True + assert r.evaluate({Number(100): {ADDR1, ADDR2}}) == True + assert r.evaluate({Number(100): {ADDR1, ADDR2, ADDR3}}) == False def test_rule_yaml_count_string(): @@ -208,10 +217,10 @@ def 
test_rule_yaml_count_string(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({String("foo"): {}}) == False - assert r.evaluate({String("foo"): {1}}) == False - assert r.evaluate({String("foo"): {1, 2}}) == True - assert r.evaluate({String("foo"): {1, 2, 3}}) == False + assert r.evaluate({String("foo"): set()}) == False + assert r.evaluate({String("foo"): {ADDR1}}) == False + assert r.evaluate({String("foo"): {ADDR1, ADDR2}}) == True + assert r.evaluate({String("foo"): {ADDR1, ADDR2, ADDR3}}) == False def test_invalid_rule_feature(): @@ -481,11 +490,11 @@ def test_count_number_symbol(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Number(2): {}}) == False - assert r.evaluate({Number(2): {1}}) == True - assert r.evaluate({Number(2): {1, 2}}) == False - assert r.evaluate({Number(0x100, description="symbol name"): {1}}) == False - assert r.evaluate({Number(0x100, description="symbol name"): {1, 2, 3}}) == True + assert r.evaluate({Number(2): set()}) == False + assert r.evaluate({Number(2): {ADDR1}}) == True + assert r.evaluate({Number(2): {ADDR1, ADDR2}}) == False + assert r.evaluate({Number(0x100, description="symbol name"): {ADDR1}}) == False + assert r.evaluate({Number(0x100, description="symbol name"): {ADDR1, ADDR2, ADDR3}}) == True def test_invalid_number(): @@ -567,11 +576,11 @@ def test_count_offset_symbol(): """ ) r = capa.rules.Rule.from_yaml(rule) - assert r.evaluate({Offset(2): {}}) == False - assert r.evaluate({Offset(2): {1}}) == True - assert r.evaluate({Offset(2): {1, 2}}) == False - assert r.evaluate({Offset(0x100, description="symbol name"): {1}}) == False - assert r.evaluate({Offset(0x100, description="symbol name"): {1, 2, 3}}) == True + assert r.evaluate({Offset(2): set()}) == False + assert r.evaluate({Offset(2): {ADDR1}}) == True + assert r.evaluate({Offset(2): {ADDR1, ADDR2}}) == False + assert r.evaluate({Offset(0x100, description="symbol name"): {ADDR1}}) == False + assert r.evaluate({Offset(0x100, 
description="symbol name"): {ADDR1, ADDR2, ADDR3}}) == True def test_invalid_offset(): @@ -966,10 +975,10 @@ def test_property_access(): """ ) ) - assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.READ): {1}}) == True + assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.READ): {ADDR1}}) == True - assert r.evaluate({Property("System.IO.FileInfo::Length"): {1}}) == False - assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.WRITE): {1}}) == False + assert r.evaluate({Property("System.IO.FileInfo::Length"): {ADDR1}}) == False + assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.WRITE): {ADDR1}}) == False def test_property_access_symbol(): @@ -986,7 +995,7 @@ def test_property_access_symbol(): ) assert ( r.evaluate( - {Property("System.IO.FileInfo::Length", access=FeatureAccess.READ, description="some property"): {1}} + {Property("System.IO.FileInfo::Length", access=FeatureAccess.READ, description="some property"): {ADDR1}} ) == True )