lots of mypy

This commit is contained in:
Willi Ballenthin
2022-12-14 10:37:39 +01:00
parent b1d6fcd6c8
commit b819033da0
29 changed files with 410 additions and 233 deletions

View File

@@ -8,7 +8,7 @@
import copy import copy
import collections import collections
from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable from typing import TYPE_CHECKING, Set, Dict, List, Tuple, Mapping, Iterable, Iterator, Union, cast
import capa.perf import capa.perf
import capa.features.common import capa.features.common
@@ -60,17 +60,24 @@ class Statement:
""" """
raise NotImplementedError() raise NotImplementedError()
def get_children(self): def get_children(self) -> Iterator[Union["Statement", Feature]]:
if hasattr(self, "child"): if hasattr(self, "child"):
yield self.child # this really confuses mypy because the property may not exist
# since its defined in the subclasses.
child = self.child # type: ignore
assert isinstance(child, (Statement, Feature))
yield child
if hasattr(self, "children"): if hasattr(self, "children"):
for child in getattr(self, "children"): for child in getattr(self, "children"):
assert isinstance(child, (Statement, Feature))
yield child yield child
def replace_child(self, existing, new): def replace_child(self, existing, new):
if hasattr(self, "child"): if hasattr(self, "child"):
if self.child is existing: # this really confuses mypy because the property may not exist
# since its defined in the subclasses.
if self.child is existing: # type: ignore
self.child = new self.child = new
if hasattr(self, "children"): if hasattr(self, "children"):

View File

@@ -200,8 +200,9 @@ class Substring(String):
# mapping from string value to list of locations. # mapping from string value to list of locations.
# will unique the locations later on. # will unique the locations later on.
matches = collections.defaultdict(list) matches: collections.defaultdict[str, Set[Address]] = collections.defaultdict(set)
assert isinstance(self.value, str)
for feature, locations in ctx.items(): for feature, locations in ctx.items():
if not isinstance(feature, (String,)): if not isinstance(feature, (String,)):
continue continue
@@ -211,32 +212,29 @@ class Substring(String):
raise ValueError("unexpected feature value type") raise ValueError("unexpected feature value type")
if self.value in feature.value: if self.value in feature.value:
matches[feature.value].extend(locations) matches[feature.value].update(locations)
if short_circuit: if short_circuit:
# we found one matching string, thats sufficient to match. # we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode. # don't collect other matching strings in this mode.
break break
if matches: if matches:
# finalize: defaultdict -> dict
# which makes json serialization easier
matches = dict(matches)
# collect all locations # collect all locations
locations = set() locations = set()
for s in matches.keys(): for locs in matches.values():
matches[s] = list(set(matches[s])) locations.update(locs)
locations.update(matches[s])
# unlike other features, we cannot return put a reference to `self` directly in a `Result`. # unlike other features, we cannot return put a reference to `self` directly in a `Result`.
# this is because `self` may match on many strings, so we can't stuff the matched value into it. # this is because `self` may match on many strings, so we can't stuff the matched value into it.
# instead, return a new instance that has a reference to both the substring and the matched values. # instead, return a new instance that has a reference to both the substring and the matched values.
return Result(True, _MatchedSubstring(self, matches), [], locations=locations) return Result(True, _MatchedSubstring(self, dict(matches)), [], locations=locations)
else: else:
return Result(False, _MatchedSubstring(self, {}), []) return Result(False, _MatchedSubstring(self, {}), [])
def __str__(self): def __str__(self):
return "substring(%s)" % self.value v = self.value
assert isinstance(v, str)
return "substring(%s)" % v
class _MatchedSubstring(Substring): class _MatchedSubstring(Substring):
@@ -261,6 +259,7 @@ class _MatchedSubstring(Substring):
self.matches = matches self.matches = matches
def __str__(self): def __str__(self):
assert isinstance(self.value, str)
return 'substring("%s", matches = %s)' % ( return 'substring("%s", matches = %s)' % (
self.value, self.value,
", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())), ", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())),
@@ -292,7 +291,7 @@ class Regex(String):
# mapping from string value to list of locations. # mapping from string value to list of locations.
# will unique the locations later on. # will unique the locations later on.
matches = collections.defaultdict(list) matches: collections.defaultdict[str, Set[Address]] = collections.defaultdict(set)
for feature, locations in ctx.items(): for feature, locations in ctx.items():
if not isinstance(feature, (String,)): if not isinstance(feature, (String,)):
@@ -307,32 +306,28 @@ class Regex(String):
# using this mode cleans is more convenient for rule authors, # using this mode cleans is more convenient for rule authors,
# so that they don't have to prefix/suffix their terms like: /.*foo.*/. # so that they don't have to prefix/suffix their terms like: /.*foo.*/.
if self.re.search(feature.value): if self.re.search(feature.value):
matches[feature.value].extend(locations) matches[feature.value].update(locations)
if short_circuit: if short_circuit:
# we found one matching string, thats sufficient to match. # we found one matching string, thats sufficient to match.
# don't collect other matching strings in this mode. # don't collect other matching strings in this mode.
break break
if matches: if matches:
# finalize: defaultdict -> dict
# which makes json serialization easier
matches = dict(matches)
# collect all locations # collect all locations
locations = set() locations = set()
for s in matches.keys(): for locs in matches.values():
matches[s] = list(set(matches[s])) locations.update(locs)
locations.update(matches[s])
# unlike other features, we cannot return put a reference to `self` directly in a `Result`. # unlike other features, we cannot return put a reference to `self` directly in a `Result`.
# this is because `self` may match on many strings, so we can't stuff the matched value into it. # this is because `self` may match on many strings, so we can't stuff the matched value into it.
# instead, return a new instance that has a reference to both the regex and the matched values. # instead, return a new instance that has a reference to both the regex and the matched values.
# see #262. # see #262.
return Result(True, _MatchedRegex(self, matches), [], locations=locations) return Result(True, _MatchedRegex(self, dict(matches)), [], locations=locations)
else: else:
return Result(False, _MatchedRegex(self, {}), []) return Result(False, _MatchedRegex(self, {}), [])
def __str__(self): def __str__(self):
assert isinstance(self.value, str)
return "regex(string =~ %s)" % self.value return "regex(string =~ %s)" % self.value
@@ -358,6 +353,7 @@ class _MatchedRegex(Regex):
self.matches = matches self.matches = matches
def __str__(self): def __str__(self):
assert isinstance(self.value, str)
return "regex(string =~ %s, matches = %s)" % ( return "regex(string =~ %s, matches = %s)" % (
self.value, self.value,
", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())), ", ".join(map(lambda s: '"' + s + '"', (self.matches or {}).keys())),
@@ -380,16 +376,19 @@ class Bytes(Feature):
capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature"] += 1
capa.perf.counters["evaluate.feature.bytes"] += 1 capa.perf.counters["evaluate.feature.bytes"] += 1
assert isinstance(self.value, bytes)
for feature, locations in ctx.items(): for feature, locations in ctx.items():
if not isinstance(feature, (Bytes,)): if not isinstance(feature, (Bytes,)):
continue continue
assert isinstance(feature.value, bytes)
if feature.value.startswith(self.value): if feature.value.startswith(self.value):
return Result(True, self, [], locations=locations) return Result(True, self, [], locations=locations)
return Result(False, self, []) return Result(False, self, [])
def get_value_str(self): def get_value_str(self):
assert isinstance(self.value, bytes)
return hex_string(bytes_to_str(self.value)) return hex_string(bytes_to_str(self.value))

View File

@@ -107,8 +107,18 @@ class DnUnmanagedMethod:
return f"{module}.{method}" return f"{module}.{method}"
def validate_has_dotnet(pe: dnfile.dnPE):
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.Flags is not None
def resolve_dotnet_token(pe: dnfile.dnPE, token: Token) -> Any: def resolve_dotnet_token(pe: dnfile.dnPE, token: Token) -> Any:
"""map generic token to string or table row""" """map generic token to string or table row"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
if isinstance(token, StringToken): if isinstance(token, StringToken):
user_string: Optional[str] = read_dotnet_user_string(pe, token) user_string: Optional[str] = read_dotnet_user_string(pe, token)
if user_string is None: if user_string is None:
@@ -143,6 +153,10 @@ def read_dotnet_method_body(pe: dnfile.dnPE, row: dnfile.mdtable.MethodDefRow) -
def read_dotnet_user_string(pe: dnfile.dnPE, token: StringToken) -> Optional[str]: def read_dotnet_user_string(pe: dnfile.dnPE, token: StringToken) -> Optional[str]:
"""read user string from #US stream""" """read user string from #US stream"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.user_strings is not None
try: try:
user_string: Optional[dnfile.stream.UserString] = pe.net.user_strings.get_us(token.rid) user_string: Optional[dnfile.stream.UserString] = pe.net.user_strings.get_us(token.rid)
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
@@ -169,6 +183,11 @@ def get_dotnet_managed_imports(pe: dnfile.dnPE) -> Iterator[DnType]:
TypeName (index into String heap) TypeName (index into String heap)
TypeNamespace (index into String heap) TypeNamespace (index into String heap)
""" """
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.MemberRef is not None
for (rid, row) in enumerate(iter_dotnet_table(pe, "MemberRef")): for (rid, row) in enumerate(iter_dotnet_table(pe, "MemberRef")):
if not isinstance(row.Class.row, dnfile.mdtable.TypeRefRow): if not isinstance(row.Class.row, dnfile.mdtable.TypeRefRow):
continue continue
@@ -258,6 +277,11 @@ def get_dotnet_properties(pe: dnfile.dnPE) -> Iterator[DnType]:
def get_dotnet_managed_method_bodies(pe: dnfile.dnPE) -> Iterator[Tuple[int, CilMethodBody]]: def get_dotnet_managed_method_bodies(pe: dnfile.dnPE) -> Iterator[Tuple[int, CilMethodBody]]:
"""get managed methods from MethodDef table""" """get managed methods from MethodDef table"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.MethodDef is not None
if not hasattr(pe.net.mdtables, "MethodDef"): if not hasattr(pe.net.mdtables, "MethodDef"):
return return
@@ -307,15 +331,28 @@ def calculate_dotnet_token_value(table: int, rid: int) -> int:
def is_dotnet_table_valid(pe: dnfile.dnPE, table_name: str) -> bool: def is_dotnet_table_valid(pe: dnfile.dnPE, table_name: str) -> bool:
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
return bool(getattr(pe.net.mdtables, table_name, None)) return bool(getattr(pe.net.mdtables, table_name, None))
def is_dotnet_mixed_mode(pe: dnfile.dnPE) -> bool: def is_dotnet_mixed_mode(pe: dnfile.dnPE) -> bool:
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.Flags is not None
return not bool(pe.net.Flags.CLR_ILONLY) return not bool(pe.net.Flags.CLR_ILONLY)
def iter_dotnet_table(pe: dnfile.dnPE, name: str) -> Iterator[Any]: def iter_dotnet_table(pe: dnfile.dnPE, name: str) -> Iterator[Any]:
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
if not is_dotnet_table_valid(pe, name): if not is_dotnet_table_valid(pe, name):
return return
for row in getattr(pe.net.mdtables, name): for row in getattr(pe.net.mdtables, name):
yield row yield row

View File

@@ -19,9 +19,19 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[Feature, Address]]:
yield OS(OS_ANY), NO_ADDRESS yield OS(OS_ANY), NO_ADDRESS
def extract_file_arch(pe, **kwargs) -> Iterator[Tuple[Feature, Address]]: def validate_has_dotnet(pe: dnfile.dnPE):
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.Flags is not None
def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Feature, Address]]:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020 # to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred # .NET 4.5 added option: any CPU, 32-bit preferred
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.Flags is not None
if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE: if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
yield Arch(ARCH_I386), NO_ADDRESS yield Arch(ARCH_I386), NO_ADDRESS
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS: elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
@@ -71,6 +81,10 @@ class DnfileFeatureExtractor(FeatureExtractor):
# self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT # self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT
# True: native EP: Token # True: native EP: Token
# False: managed EP: RVA # False: managed EP: RVA
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.struct is not None
return self.pe.net.struct.EntryPointTokenOrRva return self.pe.net.struct.EntryPointTokenOrRva
def extract_global_features(self): def extract_global_features(self):
@@ -83,13 +97,32 @@ class DnfileFeatureExtractor(FeatureExtractor):
return bool(self.pe.net) return bool(self.pe.net)
def is_mixed_mode(self) -> bool: def is_mixed_mode(self) -> bool:
validate_has_dotnet(self.pe)
assert self.pe is not None
assert self.pe.net is not None
assert self.pe.net.Flags is not None
return not bool(self.pe.net.Flags.CLR_ILONLY) return not bool(self.pe.net.Flags.CLR_ILONLY)
def get_runtime_version(self) -> Tuple[int, int]: def get_runtime_version(self) -> Tuple[int, int]:
validate_has_dotnet(self.pe)
assert self.pe is not None
assert self.pe.net is not None
assert self.pe.net.struct is not None
return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion
def get_meta_version_string(self) -> str: def get_meta_version_string(self) -> str:
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8") validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.metadata is not None
assert self.pe.net.metadata.struct is not None
assert self.pe.net.metadata.struct.Version is not None
vbuf = self.pe.net.metadata.struct.Version
assert isinstance(vbuf, bytes)
return vbuf.rstrip(b"\x00").decode("utf-8")
def get_functions(self): def get_functions(self):
raise NotImplementedError("DnfileFeatureExtractor can only be used to extract file features") raise NotImplementedError("DnfileFeatureExtractor can only be used to extract file features")

View File

@@ -40,6 +40,12 @@ def extract_file_format(**kwargs) -> Iterator[Tuple[Format, Address]]:
yield Format(FORMAT_DOTNET), NO_ADDRESS yield Format(FORMAT_DOTNET), NO_ADDRESS
def validate_has_dotnet(pe: dnfile.dnPE):
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.Flags is not None
def extract_file_import_names(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Import, Address]]: def extract_file_import_names(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Import, Address]]:
for method in get_dotnet_managed_imports(pe): for method in get_dotnet_managed_imports(pe):
# like System.IO.File::OpenRead # like System.IO.File::OpenRead
@@ -78,6 +84,12 @@ def extract_file_namespace_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple
def extract_file_class_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Class, Address]]: def extract_file_class_features(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Class, Address]]:
"""emit class features from TypeRef and TypeDef tables""" """emit class features from TypeRef and TypeDef tables"""
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.mdtables is not None
assert pe.net.mdtables.TypeDef is not None
assert pe.net.mdtables.TypeRef is not None
for (rid, row) in enumerate(iter_dotnet_table(pe, "TypeDef")): for (rid, row) in enumerate(iter_dotnet_table(pe, "TypeDef")):
token = calculate_dotnet_token_value(pe.net.mdtables.TypeDef.number, rid + 1) token = calculate_dotnet_token_value(pe.net.mdtables.TypeDef.number, rid + 1)
yield Class(DnType.format_name(row.TypeName, namespace=row.TypeNamespace)), DNTokenAddress(token) yield Class(DnType.format_name(row.TypeName, namespace=row.TypeNamespace)), DNTokenAddress(token)
@@ -94,6 +106,10 @@ def extract_file_os(**kwargs) -> Iterator[Tuple[OS, Address]]:
def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Arch, Address]]: def extract_file_arch(pe: dnfile.dnPE, **kwargs) -> Iterator[Tuple[Arch, Address]]:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020 # to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred # .NET 4.5 added option: any CPU, 32-bit preferred
validate_has_dotnet(pe)
assert pe.net is not None
assert pe.net.Flags is not None
if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE: if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
yield Arch(ARCH_I386), NO_ADDRESS yield Arch(ARCH_I386), NO_ADDRESS
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS: elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
@@ -155,6 +171,10 @@ class DotnetFileFeatureExtractor(FeatureExtractor):
# self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT # self.pe.net.Flags.CLT_NATIVE_ENTRYPOINT
# True: native EP: Token # True: native EP: Token
# False: managed EP: RVA # False: managed EP: RVA
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.struct is not None
return self.pe.net.struct.EntryPointTokenOrRva return self.pe.net.struct.EntryPointTokenOrRva
def extract_global_features(self): def extract_global_features(self):
@@ -170,10 +190,25 @@ class DotnetFileFeatureExtractor(FeatureExtractor):
return is_dotnet_mixed_mode(self.pe) return is_dotnet_mixed_mode(self.pe)
def get_runtime_version(self) -> Tuple[int, int]: def get_runtime_version(self) -> Tuple[int, int]:
validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.struct is not None
assert self.pe.net.struct.MajorRuntimeVersion is not None
assert self.pe.net.struct.MinorRuntimeVersion is not None
return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion
def get_meta_version_string(self) -> str: def get_meta_version_string(self) -> str:
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8") validate_has_dotnet(self.pe)
assert self.pe.net is not None
assert self.pe.net.metadata is not None
assert self.pe.net.metadata.struct is not None
assert self.pe.net.metadata.struct.Version is not None
vbuf = self.pe.net.metadata.struct.Version
assert isinstance(vbuf, bytes)
return vbuf.rstrip(b"\x00").decode("utf-8")
def get_functions(self): def get_functions(self):
raise NotImplementedError("DotnetFileFeatureExtractor can only be used to extract file features") raise NotImplementedError("DotnetFileFeatureExtractor can only be used to extract file features")

View File

@@ -52,26 +52,26 @@ class NullFeatureExtractor(FeatureExtractor):
yield FunctionHandle(address, None) yield FunctionHandle(address, None)
def extract_function_features(self, f): def extract_function_features(self, f):
for address, feature in self.functions.get(f.address, {}).features: for address, feature in self.functions[f.address].features:
yield feature, address yield feature, address
def get_basic_blocks(self, f): def get_basic_blocks(self, f):
for address in sorted(self.functions.get(f.address, {}).basic_blocks.keys()): for address in sorted(self.functions[f.address].basic_blocks.keys()):
yield BBHandle(address, None) yield BBHandle(address, None)
def extract_basic_block_features(self, f, bb): def extract_basic_block_features(self, f, bb):
for address, feature in self.functions.get(f.address, {}).basic_blocks.get(bb.address, {}).features: for address, feature in self.functions[f.address].basic_blocks[bb.address].features:
yield feature, address yield feature, address
def get_instructions(self, f, bb): def get_instructions(self, f, bb):
for address in sorted(self.functions.get(f.address, {}).basic_blocks.get(bb.address, {}).instructions.keys()): for address in sorted(self.functions[f.address].basic_blocks[bb.address].instructions.keys()):
yield InsnHandle(address, None) yield InsnHandle(address, None)
def extract_insn_features(self, f, bb, insn): def extract_insn_features(self, f, bb, insn):
for address, feature in ( for address, feature in (
self.functions.get(f.address, {}) self.functions[f.address]
.basic_blocks.get(bb.address, {}) .basic_blocks[bb.address]
.instructions.get(insn.address, {}) .instructions[insn.address]
.features .features
): ):
yield feature, address yield feature, address

View File

@@ -133,7 +133,8 @@ def extract_file_features(pe, buf):
""" """
for file_handler in FILE_HANDLERS: for file_handler in FILE_HANDLERS:
for feature, va in file_handler(pe=pe, buf=buf): # file_handler: type: (pe, bytes) -> Iterable[Tuple[Feature, Address]]
for feature, va in file_handler(pe=pe, buf=buf): # type: ignore
yield feature, va yield feature, va
@@ -160,7 +161,8 @@ def extract_global_features(pe, buf):
Tuple[Feature, VA]: a feature and its location. Tuple[Feature, VA]: a feature and its location.
""" """
for handler in GLOBAL_HANDLERS: for handler in GLOBAL_HANDLERS:
for feature, va in handler(pe=pe, buf=buf): # file_handler: type: (pe, bytes) -> Iterable[Tuple[Feature, Address]]
for feature, va in handler(pe=pe, buf=buf): # type: ignore
yield feature, va yield feature, va

View File

@@ -88,7 +88,8 @@ def extract_features(smda_report, buf):
""" """
for file_handler in FILE_HANDLERS: for file_handler in FILE_HANDLERS:
for feature, addr in file_handler(smda_report=smda_report, buf=buf): # file_handler: type: (smda_report, bytes) -> Iterable[Tuple[Feature, Address]]
for feature, addr in file_handler(smda_report=smda_report, buf=buf): # type: ignore
yield feature, addr yield feature, addr

View File

@@ -11,7 +11,7 @@ import copy
import logging import logging
import itertools import itertools
import collections import collections
from typing import Set, Dict, Optional from typing import Set, Dict, Optional, List, Any
import idaapi import idaapi
import ida_kernwin import ida_kernwin
@@ -72,14 +72,14 @@ def trim_function_name(f, max_length=25):
def find_func_features(fh: FunctionHandle, extractor): def find_func_features(fh: FunctionHandle, extractor):
""" """ """ """
func_features: Dict[Feature, Set] = collections.defaultdict(set) func_features: Dict[Feature, Set[Address]] = collections.defaultdict(set)
bb_features: Dict[Address, Dict] = collections.defaultdict(dict) bb_features: Dict[Address, Dict[Feature, Set[Address]]] = collections.defaultdict(dict)
for (feature, addr) in extractor.extract_function_features(fh): for (feature, addr) in extractor.extract_function_features(fh):
func_features[feature].add(addr) func_features[feature].add(addr)
for bbh in extractor.get_basic_blocks(fh): for bbh in extractor.get_basic_blocks(fh):
_bb_features = collections.defaultdict(set) _bb_features: Dict[Feature, Set[Address]] = collections.defaultdict(set)
for (feature, addr) in extractor.extract_basic_block_features(fh, bbh): for (feature, addr) in extractor.extract_basic_block_features(fh, bbh):
_bb_features[feature].add(addr) _bb_features[feature].add(addr)
@@ -239,53 +239,52 @@ class CapaSettingsInputDialog(QtWidgets.QDialog):
class CapaExplorerForm(idaapi.PluginForm): class CapaExplorerForm(idaapi.PluginForm):
"""form element for plugin interface""" """form element for plugin interface"""
def __init__(self, name, option=Options.DEFAULT): def __init__(self, name: str, option=Options.DEFAULT):
"""initialize form elements""" """initialize form elements"""
super().__init__() super().__init__()
self.form_title = name self.form_title: str = name
self.process_total = 0 self.process_total: int = 0
self.process_count = 0 self.process_count: int = 0
self.parent = None self.parent: Any # QtWidget
self.ida_hooks = None self.ida_hooks: CapaExplorerIdaHooks
self.doc: Optional[capa.render.result_document.ResultDocument] = None self.doc: Optional[capa.render.result_document.ResultDocument] = None
self.rule_paths = None self.rule_paths: Optional[List[str]]
self.rules_cache = None self.rules_cache: Optional[List[capa.rules.Rule]]
self.ruleset_cache = None self.ruleset_cache: Optional[capa.rules.RuleSet]
# models # models
self.model_data = None self.model_data: CapaExplorerDataModel
self.range_model_proxy = None self.range_model_proxy: CapaExplorerRangeProxyModel
self.search_model_proxy = None self.search_model_proxy: CapaExplorerSearchProxyModel
# UI controls # UI controls
self.view_limit_results_by_function = None self.view_limit_results_by_function: QtWidgets.QCheckBox
self.view_show_results_by_function = None self.view_show_results_by_function: QtWidgets.QCheckBox
self.view_search_bar = None self.view_search_bar: QtWidgets.QLineEdit
self.view_tree = None self.view_tree: CapaExplorerQtreeView
self.view_rulegen = None self.view_tabs: QtWidgets.QTabWidget
self.view_tabs = None
self.view_tab_rulegen = None self.view_tab_rulegen = None
self.view_status_label = None self.view_status_label: QtWidgets.QLabel
self.view_buttons = None self.view_buttons: QtWidgets.QHBoxLayout
self.view_analyze_button = None self.view_analyze_button: QtWidgets.QPushButton
self.view_reset_button = None self.view_reset_button: QtWidgets.QPushButton
self.view_settings_button = None self.view_settings_button: QtWidgets.QPushButton
self.view_save_button = None self.view_save_button: QtWidgets.QPushButton
self.view_rulegen_preview = None self.view_rulegen_preview: CapaExplorerRulegenPreview
self.view_rulegen_features = None self.view_rulegen_features: CapaExplorerRulegenFeatures
self.view_rulegen_editor = None self.view_rulegen_editor: CapaExplorerRulegenEditor
self.view_rulegen_header_label = None self.view_rulegen_header_label: QtWidgets.QLabel
self.view_rulegen_search = None self.view_rulegen_search: QtWidgets.QLineEdit
self.view_rulegen_limit_features_by_ea = None self.view_rulegen_limit_features_by_ea: QtWidgets.QCheckBox
self.rulegen_current_function = None self.rulegen_current_function: Optional[FunctionHandle]
self.rulegen_bb_features_cache = {} self.rulegen_bb_features_cache: Dict[Address, Dict[Feature, Set[Address]]] = {}
self.rulegen_func_features_cache = {} self.rulegen_func_features_cache: Dict[Feature, Set[Address]] = {}
self.rulegen_file_features_cache = {} self.rulegen_file_features_cache: Dict[Feature, Set[Address]] = {}
self.view_rulegen_status_label = None self.view_rulegen_status_label: QtWidgets.QLabel
self.Show() self.Show()
@@ -762,6 +761,9 @@ class CapaExplorerForm(idaapi.PluginForm):
if not self.load_capa_rules(): if not self.load_capa_rules():
return False return False
assert self.rules_cache is not None
assert self.ruleset_cache is not None
if ida_kernwin.user_cancelled(): if ida_kernwin.user_cancelled():
logger.info("User cancelled analysis.") logger.info("User cancelled analysis.")
return False return False
@@ -822,6 +824,13 @@ class CapaExplorerForm(idaapi.PluginForm):
return False return False
try: try:
# either the results are cached and the doc already exists,
# or the doc was just created above
assert self.doc is not None
# same with rules cache, either it's cached or it was just loaded
assert self.rules_cache is not None
assert self.ruleset_cache is not None
self.model_data.render_capa_doc(self.doc, self.view_show_results_by_function.isChecked()) self.model_data.render_capa_doc(self.doc, self.view_show_results_by_function.isChecked())
self.set_view_status_label( self.set_view_status_label(
"capa rules directory: %s (%d rules)" % (settings.user[CAPA_SETTINGS_RULE_PATH], len(self.rules_cache)) "capa rules directory: %s (%d rules)" % (settings.user[CAPA_SETTINGS_RULE_PATH], len(self.rules_cache))
@@ -871,6 +880,9 @@ class CapaExplorerForm(idaapi.PluginForm):
else: else:
logger.info('Using cached ruleset, click "Reset" to reload rules from disk.') logger.info('Using cached ruleset, click "Reset" to reload rules from disk.')
assert self.rules_cache is not None
assert self.ruleset_cache is not None
if ida_kernwin.user_cancelled(): if ida_kernwin.user_cancelled():
logger.info("User cancelled analysis.") logger.info("User cancelled analysis.")
return False return False
@@ -891,7 +903,8 @@ class CapaExplorerForm(idaapi.PluginForm):
try: try:
f = idaapi.get_func(idaapi.get_screen_ea()) f = idaapi.get_func(idaapi.get_screen_ea())
if f: if f:
fh: FunctionHandle = extractor.get_function(f.start_ea) fh: Optional[FunctionHandle] = extractor.get_function(f.start_ea)
assert fh is not None
self.rulegen_current_function = fh self.rulegen_current_function = fh
func_features, bb_features = find_func_features(fh, extractor) func_features, bb_features = find_func_features(fh, extractor)
@@ -1053,6 +1066,8 @@ class CapaExplorerForm(idaapi.PluginForm):
def update_rule_status(self, rule_text): def update_rule_status(self, rule_text):
""" """ """ """
assert self.rules_cache is not None
if not self.view_rulegen_editor.invisibleRootItem().childCount(): if not self.view_rulegen_editor.invisibleRootItem().childCount():
self.set_rulegen_preview_border_neutral() self.set_rulegen_preview_border_neutral()
self.view_rulegen_status_label.clear() self.view_rulegen_status_label.clear()
@@ -1077,7 +1092,7 @@ class CapaExplorerForm(idaapi.PluginForm):
rules.append(rule) rules.append(rule)
try: try:
file_features = copy.copy(self.rulegen_file_features_cache) file_features = copy.copy(dict(self.rulegen_file_features_cache))
if self.rulegen_current_function: if self.rulegen_current_function:
func_matches, bb_matches = find_func_matches( func_matches, bb_matches = find_func_matches(
self.rulegen_current_function, self.rulegen_current_function,
@@ -1093,7 +1108,7 @@ class CapaExplorerForm(idaapi.PluginForm):
_, file_matches = capa.engine.match( _, file_matches = capa.engine.match(
capa.rules.RuleSet(list(capa.rules.get_rules_and_dependencies(rules, rule.name))).file_rules, capa.rules.RuleSet(list(capa.rules.get_rules_and_dependencies(rules, rule.name))).file_rules,
file_features, file_features,
0x0, NO_ADDRESS
) )
except Exception as e: except Exception as e:
self.set_rulegen_status("Failed to match rule (%s)" % e) self.set_rulegen_status("Failed to match rule (%s)" % e)

View File

@@ -36,7 +36,7 @@ def ea_to_hex(ea):
class CapaExplorerDataItem: class CapaExplorerDataItem:
"""store data for CapaExplorerDataModel""" """store data for CapaExplorerDataModel"""
def __init__(self, parent: "CapaExplorerDataItem", data: List[str], can_check=True): def __init__(self, parent: Optional["CapaExplorerDataItem"], data: List[str], can_check=True):
"""initialize item""" """initialize item"""
self.pred = parent self.pred = parent
self._data = data self._data = data
@@ -110,7 +110,7 @@ class CapaExplorerDataItem:
except IndexError: except IndexError:
return None return None
def parent(self) -> "CapaExplorerDataItem": def parent(self) -> Optional["CapaExplorerDataItem"]:
"""get parent""" """get parent"""
return self.pred return self.pred

View File

@@ -92,7 +92,7 @@ class CapaExplorerRangeProxyModel(QtCore.QSortFilterProxyModel):
@param parent: QModelIndex of parent @param parent: QModelIndex of parent
""" """
# filter not set # filter not set
if self.min_ea is None and self.max_ea is None: if self.min_ea is None or self.max_ea is None:
return True return True
index = self.sourceModel().index(row, 0, parent) index = self.sourceModel().index(row, 0, parent)

View File

@@ -18,7 +18,7 @@ import capa.ida.helpers
import capa.features.common import capa.features.common
import capa.features.basicblock import capa.features.basicblock
from capa.ida.plugin.item import CapaExplorerFunctionItem from capa.ida.plugin.item import CapaExplorerFunctionItem
from capa.features.address import Address, _NoAddress from capa.features.address import _NoAddress, AbsoluteVirtualAddress
from capa.ida.plugin.model import CapaExplorerDataModel from capa.ida.plugin.model import CapaExplorerDataModel
MAX_SECTION_SIZE = 750 MAX_SECTION_SIZE = 750
@@ -1013,8 +1013,10 @@ class CapaExplorerRulegenFeatures(QtWidgets.QTreeWidget):
self.parent_items = {} self.parent_items = {}
def format_address(e): def format_address(e):
assert isinstance(e, Address) if isinstance(e, AbsoluteVirtualAddress):
return "%X" % e if not isinstance(e, _NoAddress) else "" return "%X" % int(e)
else:
return ""
def format_feature(feature): def format_feature(feature):
""" """ """ """

View File

@@ -66,7 +66,7 @@ from capa.features.common import (
FORMAT_DOTNET, FORMAT_DOTNET,
FORMAT_FREEZE, FORMAT_FREEZE,
) )
from capa.features.address import NO_ADDRESS from capa.features.address import NO_ADDRESS, Address
from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle, FeatureExtractor from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle, FeatureExtractor
RULES_PATH_DEFAULT_STRING = "(embedded rules)" RULES_PATH_DEFAULT_STRING = "(embedded rules)"
@@ -718,8 +718,8 @@ def compute_layout(rules, extractor, capabilities):
otherwise, we may pollute the json document with otherwise, we may pollute the json document with
a large amount of un-referenced data. a large amount of un-referenced data.
""" """
functions_by_bb = {} functions_by_bb: Dict[Address, Address] = {}
bbs_by_function = {} bbs_by_function: Dict[Address, List[Address]] = {}
for f in extractor.get_functions(): for f in extractor.get_functions():
bbs_by_function[f.address] = [] bbs_by_function[f.address] = []
for bb in extractor.get_basic_blocks(f): for bb in extractor.get_basic_blocks(f):
@@ -1016,8 +1016,7 @@ def main(argv=None):
return E_INVALID_FILE_TYPE return E_INVALID_FILE_TYPE
try: try:
rules = get_rules(args.rules, disable_progress=args.quiet) rules = capa.rules.RuleSet(get_rules(args.rules, disable_progress=args.quiet))
rules = capa.rules.RuleSet(rules)
logger.debug( logger.debug(
"successfully loaded %s rules", "successfully loaded %s rules",
@@ -1167,8 +1166,7 @@ def ida_main():
rules_path = os.path.join(get_default_root(), "rules") rules_path = os.path.join(get_default_root(), "rules")
logger.debug("rule path: %s", rules_path) logger.debug("rule path: %s", rules_path)
rules = get_rules([rules_path]) rules = capa.rules.RuleSet(get_rules([rules_path]))
rules = capa.rules.RuleSet(rules)
meta = capa.ida.helpers.collect_metadata([rules_path]) meta = capa.ida.helpers.collect_metadata([rules_path])

View File

@@ -2,7 +2,7 @@ import collections
from typing import Dict from typing import Dict
# this structure is unstable and may change before the next major release. # this structure is unstable and may change before the next major release.
counters: Dict[str, int] = collections.Counter() counters: collections.Counter[str] = collections.Counter()
def reset(): def reset():

View File

@@ -634,7 +634,7 @@ class Rule:
Returns: Returns:
List[str]: names of rules upon which this rule depends. List[str]: names of rules upon which this rule depends.
""" """
deps = set([]) deps: Set[str] = set([])
def rec(statement): def rec(statement):
if isinstance(statement, capa.features.common.MatchedRule): if isinstance(statement, capa.features.common.MatchedRule):
@@ -651,6 +651,7 @@ class Rule:
deps.update(map(lambda r: r.name, namespaces[statement.value])) deps.update(map(lambda r: r.name, namespaces[statement.value]))
else: else:
# not a namespace, assume its a rule name. # not a namespace, assume its a rule name.
assert isinstance(statement.value, str)
deps.add(statement.value) deps.add(statement.value)
elif isinstance(statement, ceng.Statement): elif isinstance(statement, ceng.Statement):
@@ -666,7 +667,11 @@ class Rule:
def _extract_subscope_rules_rec(self, statement): def _extract_subscope_rules_rec(self, statement):
if isinstance(statement, ceng.Statement): if isinstance(statement, ceng.Statement):
# for each child that is a subscope, # for each child that is a subscope,
for subscope in filter(lambda statement: isinstance(statement, ceng.Subscope), statement.get_children()): for child in statement.get_children():
if not isinstance(child, ceng.Subscope):
continue
subscope = child
# create a new rule from it. # create a new rule from it.
# the name is a randomly generated, hopefully unique value. # the name is a randomly generated, hopefully unique value.
@@ -737,7 +742,7 @@ class Rule:
return self.statement.evaluate(features, short_circuit=short_circuit) return self.statement.evaluate(features, short_circuit=short_circuit)
@classmethod @classmethod
def from_dict(cls, d, definition): def from_dict(cls, d, definition) -> "Rule":
meta = d["rule"]["meta"] meta = d["rule"]["meta"]
name = meta["name"] name = meta["name"]
# if scope is not specified, default to function scope. # if scope is not specified, default to function scope.
@@ -771,14 +776,12 @@ class Rule:
# prefer to use CLoader to be fast, see #306 # prefer to use CLoader to be fast, see #306
# on Linux, make sure you install libyaml-dev or similar # on Linux, make sure you install libyaml-dev or similar
# on Windows, get WHLs from pyyaml.org/pypi # on Windows, get WHLs from pyyaml.org/pypi
loader = yaml.CLoader
logger.debug("using libyaml CLoader.") logger.debug("using libyaml CLoader.")
return yaml.CLoader
except: except:
loader = yaml.Loader
logger.debug("unable to import libyaml CLoader, falling back to Python yaml parser.") logger.debug("unable to import libyaml CLoader, falling back to Python yaml parser.")
logger.debug("this will be slower to load rules.") logger.debug("this will be slower to load rules.")
return yaml.Loader
return loader
@staticmethod @staticmethod
def _get_ruamel_yaml_parser(): def _get_ruamel_yaml_parser():
@@ -790,8 +793,9 @@ class Rule:
# use block mode, not inline json-like mode # use block mode, not inline json-like mode
y.default_flow_style = False y.default_flow_style = False
# leave quotes unchanged # leave quotes unchanged.
y.preserve_quotes = True # manually verified this property exists, even if mypy complains.
y.preserve_quotes = True # type: ignore
# indent lists by two spaces below their parent # indent lists by two spaces below their parent
# #
@@ -802,12 +806,13 @@ class Rule:
y.indent(sequence=2, offset=2) y.indent(sequence=2, offset=2)
# avoid word wrapping # avoid word wrapping
y.width = 4096 # manually verified this property exists, even if mypy complains.
y.width = 4096 # type: ignore
return y return y
@classmethod @classmethod
def from_yaml(cls, s, use_ruamel=False): def from_yaml(cls, s, use_ruamel=False) -> "Rule":
if use_ruamel: if use_ruamel:
# ruamel enables nice formatting and doc roundtripping with comments # ruamel enables nice formatting and doc roundtripping with comments
doc = cls._get_ruamel_yaml_parser().load(s) doc = cls._get_ruamel_yaml_parser().load(s)
@@ -817,7 +822,7 @@ class Rule:
return cls.from_dict(doc, s) return cls.from_dict(doc, s)
@classmethod @classmethod
def from_yaml_file(cls, path, use_ruamel=False): def from_yaml_file(cls, path, use_ruamel=False) -> "Rule":
with open(path, "rb") as f: with open(path, "rb") as f:
try: try:
rule = cls.from_yaml(f.read().decode("utf-8"), use_ruamel=use_ruamel) rule = cls.from_yaml(f.read().decode("utf-8"), use_ruamel=use_ruamel)
@@ -832,7 +837,7 @@ class Rule:
except pydantic.ValidationError as e: except pydantic.ValidationError as e:
raise InvalidRuleWithPath(path, str(e)) from e raise InvalidRuleWithPath(path, str(e)) from e
def to_yaml(self): def to_yaml(self) -> str:
# reformat the yaml document with a common style. # reformat the yaml document with a common style.
# this includes: # this includes:
# - ordering the meta elements # - ordering the meta elements
@@ -1261,7 +1266,7 @@ class RuleSet:
return (easy_rules_by_feature, hard_rules) return (easy_rules_by_feature, hard_rules)
@staticmethod @staticmethod
def _get_rules_for_scope(rules, scope): def _get_rules_for_scope(rules, scope) -> List[Rule]:
""" """
given a collection of rules, collect the rules that are needed at the given scope. given a collection of rules, collect the rules that are needed at the given scope.
these rules are ordered topologically. these rules are ordered topologically.
@@ -1269,7 +1274,7 @@ class RuleSet:
don't include auto-generated "subscope" rules. don't include auto-generated "subscope" rules.
we want to include general "lib" rules here - even if they are not dependencies of other rules, see #398 we want to include general "lib" rules here - even if they are not dependencies of other rules, see #398
""" """
scope_rules = set([]) scope_rules: Set[Rule] = set([])
# we need to process all rules, not just rules with the given scope. # we need to process all rules, not just rules with the given scope.
# this is because rules with a higher scope, e.g. file scope, may have subscope rules # this is because rules with a higher scope, e.g. file scope, may have subscope rules
@@ -1283,7 +1288,7 @@ class RuleSet:
return get_rules_with_scope(topologically_order_rules(list(scope_rules)), scope) return get_rules_with_scope(topologically_order_rules(list(scope_rules)), scope)
@staticmethod @staticmethod
def _extract_subscope_rules(rules): def _extract_subscope_rules(rules) -> List[Rule]:
""" """
process the given sequence of rules. process the given sequence of rules.
for each one, extract any embedded subscope rules into their own rule. for each one, extract any embedded subscope rules into their own rule.

2
rules

Submodule rules updated: 2bc58afb51...5ba70c97d2

View File

@@ -152,8 +152,7 @@ def main(argv=None):
capa.main.handle_common_args(args) capa.main.handle_common_args(args)
try: try:
rules = capa.main.get_rules(args.rules) rules = capa.rules.RuleSet(capa.main.get_rules(args.rules))
rules = capa.rules.RuleSet(rules)
logger.info("successfully loaded %s rules", len(rules)) logger.info("successfully loaded %s rules", len(rules))
except (IOError, capa.rules.InvalidRule, capa.rules.InvalidRuleSet) as e: except (IOError, capa.rules.InvalidRule, capa.rules.InvalidRuleSet) as e:
logger.error("%s", str(e)) logger.error("%s", str(e))

View File

@@ -64,7 +64,6 @@ unsupported = ["characteristic", "mnemonic", "offset", "subscope", "Range"]
# collect all converted rules to be able to check if we have needed sub rules for match: # collect all converted rules to be able to check if we have needed sub rules for match:
converted_rules = [] converted_rules = []
count_incomplete = 0
default_tags = "CAPA " default_tags = "CAPA "
@@ -537,7 +536,8 @@ def output_unsupported_capa_rules(yaml, capa_rulename, url, reason):
unsupported_capa_rules_names.write(url.encode("utf-8") + b"\n") unsupported_capa_rules_names.write(url.encode("utf-8") + b"\n")
def convert_rules(rules, namespaces, cround): def convert_rules(rules, namespaces, cround, make_priv):
count_incomplete = 0
for rule in rules.rules.values(): for rule in rules.rules.values():
rule_name = convert_rule_name(rule.name) rule_name = convert_rule_name(rule.name)
@@ -652,7 +652,6 @@ def convert_rules(rules, namespaces, cround):
if meta_name and meta_value: if meta_name and meta_value:
yara_meta += "\t" + meta_name + ' = "' + meta_value + '"\n' yara_meta += "\t" + meta_name + ' = "' + meta_value + '"\n'
rule_name_bonus = ""
if rule_comment: if rule_comment:
yara_meta += '\tcomment = "' + rule_comment + '"\n' yara_meta += '\tcomment = "' + rule_comment + '"\n'
yara_meta += '\tdate = "' + today + '"\n' yara_meta += '\tdate = "' + today + '"\n'
@@ -679,12 +678,13 @@ def convert_rules(rules, namespaces, cround):
# TODO: now the rule is finished and could be automatically checked with the capa-testfile(s) named in meta (doing it for all of them using yara-ci upload at the moment) # TODO: now the rule is finished and could be automatically checked with the capa-testfile(s) named in meta (doing it for all of them using yara-ci upload at the moment)
output_yar(yara) output_yar(yara)
converted_rules.append(rule_name) converted_rules.append(rule_name)
global count_incomplete
count_incomplete += incomplete count_incomplete += incomplete
else: else:
output_unsupported_capa_rules(rule.to_yaml(), rule.name, url, yara_condition) output_unsupported_capa_rules(rule.to_yaml(), rule.name, url, yara_condition)
pass pass
return count_incomplete
def main(argv=None): def main(argv=None):
if argv is None: if argv is None:
@@ -696,7 +696,6 @@ def main(argv=None):
capa.main.install_common_args(parser, wanted={"tag"}) capa.main.install_common_args(parser, wanted={"tag"})
args = parser.parse_args(args=argv) args = parser.parse_args(args=argv)
global make_priv
make_priv = args.private make_priv = args.private
if args.verbose: if args.verbose:
@@ -710,9 +709,9 @@ def main(argv=None):
logging.getLogger("capa2yara").setLevel(level) logging.getLogger("capa2yara").setLevel(level)
try: try:
rules = capa.main.get_rules([args.rules], disable_progress=True) rules_ = capa.main.get_rules([args.rules], disable_progress=True)
namespaces = capa.rules.index_rules_by_namespace(list(rules)) namespaces = capa.rules.index_rules_by_namespace(rules_)
rules = capa.rules.RuleSet(rules) rules = capa.rules.RuleSet(rules_)
logger.info("successfully loaded %s rules (including subscope rules which will be ignored)", len(rules)) logger.info("successfully loaded %s rules (including subscope rules which will be ignored)", len(rules))
if args.tag: if args.tag:
rules = rules.filter_rules_by_meta(args.tag) rules = rules.filter_rules_by_meta(args.tag)
@@ -745,14 +744,15 @@ def main(argv=None):
# do several rounds of converting rules because some rules for match: might not be converted in the 1st run # do several rounds of converting rules because some rules for match: might not be converted in the 1st run
num_rules = 9999999 num_rules = 9999999
cround = 0 cround = 0
count_incomplete = 0
while num_rules != len(converted_rules) or cround < min_rounds: while num_rules != len(converted_rules) or cround < min_rounds:
cround += 1 cround += 1
logger.info("doing convert_rules(), round: " + str(cround)) logger.info("doing convert_rules(), round: " + str(cround))
num_rules = len(converted_rules) num_rules = len(converted_rules)
convert_rules(rules, namespaces, cround) count_incomplete += convert_rules(rules, namespaces, cround, make_priv)
# one last round to collect all unconverted rules # one last round to collect all unconverted rules
convert_rules(rules, namespaces, 9000) count_incomplete += convert_rules(rules, namespaces, 9000, make_priv)
stats = "\n// converted rules : " + str(len(converted_rules)) stats = "\n// converted rules : " + str(len(converted_rules))
stats += "\n// among those are incomplete : " + str(count_incomplete) stats += "\n// among those are incomplete : " + str(count_incomplete)

View File

@@ -172,7 +172,7 @@ def capa_details(rules_path, file_path, output_format="dictionary"):
meta["analysis"].update(counts) meta["analysis"].update(counts)
meta["analysis"]["layout"] = capa.main.compute_layout(rules, extractor, capabilities) meta["analysis"]["layout"] = capa.main.compute_layout(rules, extractor, capabilities)
capa_output = False capa_output: Any = False
if output_format == "dictionary": if output_format == "dictionary":
# ...as python dictionary, simplified as textable but in dictionary # ...as python dictionary, simplified as textable but in dictionary
doc = rd.ResultDocument.from_capa(meta, rules, capabilities) doc = rd.ResultDocument.from_capa(meta, rules, capabilities)

View File

@@ -28,7 +28,7 @@ def main(argv=None):
if capa.helpers.is_runtime_ida(): if capa.helpers.is_runtime_ida():
from capa.ida.helpers import IDAIO from capa.ida.helpers import IDAIO
f: BinaryIO = IDAIO() f: BinaryIO = IDAIO() # type: ignore
else: else:
if argv is None: if argv is None:

View File

@@ -902,11 +902,15 @@ def redirecting_print_to_tqdm():
old_print(*args, **kwargs) old_print(*args, **kwargs)
try: try:
# Globaly replace print with new_print # Globaly replace print with new_print.
inspect.builtins.print = new_print # Verified this works manually on Python 3.11:
# >>> import inspect
# >>> inspect.builtins
# <module 'builtins' (built-in)>
inspect.builtins.print = new_print # type: ignore
yield yield
finally: finally:
inspect.builtins.print = old_print inspect.builtins.print = old_print # type: ignore
def lint(ctx: Context): def lint(ctx: Context):
@@ -998,10 +1002,8 @@ def main(argv=None):
time0 = time.time() time0 = time.time()
try: try:
rules = capa.main.get_rules(args.rules, disable_progress=True) rules = capa.rules.RuleSet(capa.main.get_rules(args.rules, disable_progress=True))
rule_count = len(rules) logger.info("successfully loaded %s rules", len(rules))
rules = capa.rules.RuleSet(rules)
logger.info("successfully loaded %s rules", rule_count)
if args.tag: if args.tag:
rules = rules.filter_rules_by_meta(args.tag) rules = rules.filter_rules_by_meta(args.tag)
logger.debug("selected %s rules", len(rules)) logger.debug("selected %s rules", len(rules))

View File

@@ -141,8 +141,7 @@ def main(argv=None):
return -1 return -1
try: try:
rules = capa.main.get_rules(args.rules) rules = capa.rules.RuleSet(capa.main.get_rules(args.rules))
rules = capa.rules.RuleSet(rules)
logger.info("successfully loaded %s rules", len(rules)) logger.info("successfully loaded %s rules", len(rules))
if args.tag: if args.tag:
rules = rules.filter_rules_by_meta(args.tag) rules = rules.filter_rules_by_meta(args.tag)

View File

@@ -136,7 +136,7 @@ def main(argv=None):
for feature, addr in extractor.extract_file_features(): for feature, addr in extractor.extract_file_features():
print("file: %s: %s" % (format_address(addr), feature)) print("file: %s: %s" % (format_address(addr), feature))
function_handles = extractor.get_functions() function_handles = tuple(extractor.get_functions())
if args.function: if args.function:
if args.format == "freeze": if args.format == "freeze":
@@ -173,7 +173,7 @@ def ida_main():
print("file: %s: %s" % (format_address(addr), feature)) print("file: %s: %s" % (format_address(addr), feature))
return return
function_handles = extractor.get_functions() function_handles = tuple(extractor.get_functions())
if function: if function:
function_handles = tuple(filter(lambda fh: fh.inner.start_ea == function, function_handles)) function_handles = tuple(filter(lambda fh: fh.inner.start_ea == function, function_handles))

View File

@@ -8,58 +8,63 @@
from capa.engine import * from capa.engine import *
from capa.features import * from capa.features import *
from capa.features.insn import * from capa.features.insn import *
import capa.features.address
ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001)
ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002)
ADDR3 = capa.features.address.AbsoluteVirtualAddress(0x401003)
ADDR4 = capa.features.address.AbsoluteVirtualAddress(0x401004)
def test_number(): def test_number():
assert Number(1).evaluate({Number(0): {1}}) == False assert Number(1).evaluate({Number(0): {ADDR1}}) == False
assert Number(1).evaluate({Number(1): {1}}) == True assert Number(1).evaluate({Number(1): {ADDR1}}) == True
assert Number(1).evaluate({Number(2): {1, 2}}) == False assert Number(1).evaluate({Number(2): {ADDR1, ADDR2}}) == False
def test_and(): def test_and():
assert And([Number(1)]).evaluate({Number(0): {1}}) == False assert And([Number(1)]).evaluate({Number(0): {ADDR1}}) == False
assert And([Number(1)]).evaluate({Number(1): {1}}) == True assert And([Number(1)]).evaluate({Number(1): {ADDR1}}) == True
assert And([Number(1), Number(2)]).evaluate({Number(0): {1}}) == False assert And([Number(1), Number(2)]).evaluate({Number(0): {ADDR1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(1): {1}}) == False assert And([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(2): {1}}) == False assert And([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == False
assert And([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {2}}) == True assert And([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR2}}) == True
def test_or(): def test_or():
assert Or([Number(1)]).evaluate({Number(0): {1}}) == False assert Or([Number(1)]).evaluate({Number(0): {ADDR1}}) == False
assert Or([Number(1)]).evaluate({Number(1): {1}}) == True assert Or([Number(1)]).evaluate({Number(1): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(0): {1}}) == False assert Or([Number(1), Number(2)]).evaluate({Number(0): {ADDR1}}) == False
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {2}}) == True assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR2}}) == True
def test_not(): def test_not():
assert Not(Number(1)).evaluate({Number(0): {1}}) == True assert Not(Number(1)).evaluate({Number(0): {ADDR1}}) == True
assert Not(Number(1)).evaluate({Number(1): {1}}) == False assert Not(Number(1)).evaluate({Number(1): {ADDR1}}) == False
def test_some(): def test_some():
assert Some(0, [Number(1)]).evaluate({Number(0): {1}}) == True assert Some(0, [Number(1)]).evaluate({Number(0): {ADDR1}}) == True
assert Some(1, [Number(1)]).evaluate({Number(0): {1}}) == False assert Some(1, [Number(1)]).evaluate({Number(0): {ADDR1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}}) == False assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}, Number(1): {1}}) == False assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}}) == False
assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}}) == True assert Some(2, [Number(1), Number(2), Number(3)]).evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}}) == True
assert ( assert (
Some(2, [Number(1), Number(2), Number(3)]).evaluate( Some(2, [Number(1), Number(2), Number(3)]).evaluate(
{Number(0): {1}, Number(1): {1}, Number(2): {1}, Number(3): {1}} {Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}, Number(3): {ADDR1}}
) )
== True == True
) )
assert ( assert (
Some(2, [Number(1), Number(2), Number(3)]).evaluate( Some(2, [Number(1), Number(2), Number(3)]).evaluate(
{ {
Number(0): {1}, Number(0): {ADDR1},
Number(1): {1}, Number(1): {ADDR1},
Number(2): {1}, Number(2): {ADDR1},
Number(3): {1}, Number(3): {ADDR1},
Number(4): {1}, Number(4): {ADDR1},
} }
) )
== True == True
@@ -69,10 +74,10 @@ def test_some():
def test_complex(): def test_complex():
assert True == Or( assert True == Or(
[And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5), Number(6)])])] [And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5), Number(6)])])]
).evaluate({Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}}) ).evaluate({Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}})
assert False == Or([And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5)])])]).evaluate( assert False == Or([And([Number(1), Number(2)]), Or([Number(3), Some(2, [Number(4), Number(5)])])]).evaluate(
{Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}} {Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}
) )
@@ -83,54 +88,54 @@ def test_range():
# unbounded range with matching feature should always match # unbounded range with matching feature should always match
assert Range(Number(1)).evaluate({Number(1): {}}) == True assert Range(Number(1)).evaluate({Number(1): {}}) == True
assert Range(Number(1)).evaluate({Number(1): {0}}) == True assert Range(Number(1)).evaluate({Number(1): {ADDR1}}) == True
# unbounded max # unbounded max
assert Range(Number(1), min=1).evaluate({Number(1): {0}}) == True assert Range(Number(1), min=1).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), min=2).evaluate({Number(1): {0}}) == False assert Range(Number(1), min=2).evaluate({Number(1): {ADDR1}}) == False
assert Range(Number(1), min=2).evaluate({Number(1): {0, 1}}) == True assert Range(Number(1), min=2).evaluate({Number(1): {ADDR1, ADDR2}}) == True
# unbounded min # unbounded min
assert Range(Number(1), max=0).evaluate({Number(1): {0}}) == False assert Range(Number(1), max=0).evaluate({Number(1): {ADDR1}}) == False
assert Range(Number(1), max=1).evaluate({Number(1): {0}}) == True assert Range(Number(1), max=1).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {0}}) == True assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {0, 1}}) == True assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2}}) == True
assert Range(Number(1), max=2).evaluate({Number(1): {0, 1, 3}}) == False assert Range(Number(1), max=2).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}}) == False
# we can do an exact match by setting min==max # we can do an exact match by setting min==max
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {}}) == False assert Range(Number(1), min=1, max=1).evaluate({Number(1): {}}) == False
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {1}}) == True assert Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), min=1, max=1).evaluate({Number(1): {1, 2}}) == False assert Range(Number(1), min=1, max=1).evaluate({Number(1): {ADDR1, ADDR2}}) == False
# bounded range # bounded range
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {}}) == False assert Range(Number(1), min=1, max=3).evaluate({Number(1): {}}) == False
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1}}) == True assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2}}) == True assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2, 3}}) == True assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3}}) == True
assert Range(Number(1), min=1, max=3).evaluate({Number(1): {1, 2, 3, 4}}) == False assert Range(Number(1), min=1, max=3).evaluate({Number(1): {ADDR1, ADDR2, ADDR3, ADDR4}}) == False
def test_short_circuit(): def test_short_circuit():
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True
# with short circuiting, only the children up until the first satisfied child are captured. # with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=True).children) == 1 assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}, short_circuit=True).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}, short_circuit=False).children) == 2 assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}, short_circuit=False).children) == 2
def test_eval_order(): def test_eval_order():
# base cases. # base cases.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}) == True assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}) == True
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}) == True assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}) == True
# with short circuiting, only the children up until the first satisfied child are captured. # with short circuiting, only the children up until the first satisfied child are captured.
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children) == 1 assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children) == 1
assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children) == 2 assert len(Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children) == 2
assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {1}, Number(2): {1}}).children) == 1 assert len(Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}, Number(2): {ADDR1}}).children) == 1
# and its guaranteed that children are evaluated in order. # and its guaranteed that children are evaluated in order.
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement == Number(1) assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children[0].statement == Number(1)
assert Or([Number(1), Number(2)]).evaluate({Number(1): {1}}).children[0].statement != Number(2) assert Or([Number(1), Number(2)]).evaluate({Number(1): {ADDR1}}).children[0].statement != Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement == Number(2) assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement == Number(2)
assert Or([Number(1), Number(2)]).evaluate({Number(2): {1}}).children[1].statement != Number(1) assert Or([Number(1), Number(2)]).evaluate({Number(2): {ADDR1}}).children[1].statement != Number(1)

View File

@@ -98,7 +98,7 @@ def test_rule_reformat_order():
def test_rule_reformat_meta_update(): def test_rule_reformat_meta_update():
# test updating the rule content after parsing # test updating the rule content after parsing
rule = textwrap.dedent( src = textwrap.dedent(
""" """
rule: rule:
meta: meta:
@@ -116,7 +116,7 @@ def test_rule_reformat_meta_update():
""" """
) )
rule = capa.rules.Rule.from_yaml(rule) rule = capa.rules.Rule.from_yaml(src)
rule.name = "test rule" rule.name = "test rule"
assert rule.to_yaml() == EXPECTED assert rule.to_yaml() == EXPECTED

View File

@@ -218,7 +218,7 @@ def test_match_matched_rules():
# the ordering of the rules must not matter, # the ordering of the rules must not matter,
# the engine should match rules in an appropriate order. # the engine should match rules in an appropriate order.
features, _ = match( features, _ = match(
capa.rules.topologically_order_rules(reversed(rules)), capa.rules.topologically_order_rules(list(reversed(rules))),
{capa.features.insn.Number(100): {1}}, {capa.features.insn.Number(100): {1}},
0x0, 0x0,
) )

View File

@@ -19,6 +19,7 @@ def test_optional_node_from_capa():
[], [],
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement) assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.OPTIONAL assert node.statement.type == rdoc.CompoundStatementType.OPTIONAL
@@ -32,6 +33,7 @@ def test_some_node_from_capa():
], ],
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.SomeStatement) assert isinstance(node.statement, rdoc.SomeStatement)
@@ -41,6 +43,7 @@ def test_range_node_from_capa():
capa.features.insn.Number(0), capa.features.insn.Number(0),
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.RangeStatement) assert isinstance(node.statement, rdoc.RangeStatement)
@@ -51,6 +54,7 @@ def test_subscope_node_from_capa():
capa.features.insn.Number(0), capa.features.insn.Number(0),
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.SubscopeStatement) assert isinstance(node.statement, rdoc.SubscopeStatement)
@@ -62,6 +66,7 @@ def test_and_node_from_capa():
], ],
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement) assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.AND assert node.statement.type == rdoc.CompoundStatementType.AND
@@ -74,6 +79,7 @@ def test_or_node_from_capa():
], ],
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement) assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.OR assert node.statement.type == rdoc.CompoundStatementType.OR
@@ -86,115 +92,138 @@ def test_not_node_from_capa():
], ],
) )
) )
assert isinstance(node, rdoc.StatementNode)
assert isinstance(node.statement, rdoc.CompoundStatement) assert isinstance(node.statement, rdoc.CompoundStatement)
assert node.statement.type == rdoc.CompoundStatementType.NOT assert node.statement.type == rdoc.CompoundStatementType.NOT
def test_os_node_from_capa(): def test_os_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.OS("")) node = rdoc.node_from_capa(capa.features.common.OS(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OSFeature) assert isinstance(node.feature, frzf.OSFeature)
def test_arch_node_from_capa(): def test_arch_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Arch("")) node = rdoc.node_from_capa(capa.features.common.Arch(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ArchFeature) assert isinstance(node.feature, frzf.ArchFeature)
def test_format_node_from_capa(): def test_format_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Format("")) node = rdoc.node_from_capa(capa.features.common.Format(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.FormatFeature) assert isinstance(node.feature, frzf.FormatFeature)
def test_match_node_from_capa(): def test_match_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.MatchedRule("")) node = rdoc.node_from_capa(capa.features.common.MatchedRule(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.MatchFeature) assert isinstance(node.feature, frzf.MatchFeature)
def test_characteristic_node_from_capa(): def test_characteristic_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Characteristic("")) node = rdoc.node_from_capa(capa.features.common.Characteristic(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.CharacteristicFeature) assert isinstance(node.feature, frzf.CharacteristicFeature)
def test_substring_node_from_capa(): def test_substring_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Substring("")) node = rdoc.node_from_capa(capa.features.common.Substring(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.SubstringFeature) assert isinstance(node.feature, frzf.SubstringFeature)
def test_regex_node_from_capa(): def test_regex_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Regex("")) node = rdoc.node_from_capa(capa.features.common.Regex(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.RegexFeature) assert isinstance(node.feature, frzf.RegexFeature)
def test_class_node_from_capa(): def test_class_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Class("")) node = rdoc.node_from_capa(capa.features.common.Class(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ClassFeature) assert isinstance(node.feature, frzf.ClassFeature)
def test_namespace_node_from_capa(): def test_namespace_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Namespace("")) node = rdoc.node_from_capa(capa.features.common.Namespace(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.NamespaceFeature) assert isinstance(node.feature, frzf.NamespaceFeature)
def test_bytes_node_from_capa(): def test_bytes_node_from_capa():
node = rdoc.node_from_capa(capa.features.common.Bytes(b"")) node = rdoc.node_from_capa(capa.features.common.Bytes(b""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.BytesFeature) assert isinstance(node.feature, frzf.BytesFeature)
def test_export_node_from_capa(): def test_export_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.Export("")) node = rdoc.node_from_capa(capa.features.file.Export(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ExportFeature) assert isinstance(node.feature, frzf.ExportFeature)
def test_import_node_from_capa(): def test_import_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.Import("")) node = rdoc.node_from_capa(capa.features.file.Import(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.ImportFeature) assert isinstance(node.feature, frzf.ImportFeature)
def test_section_node_from_capa(): def test_section_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.Section("")) node = rdoc.node_from_capa(capa.features.file.Section(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.SectionFeature) assert isinstance(node.feature, frzf.SectionFeature)
def test_function_name_node_from_capa(): def test_function_name_node_from_capa():
node = rdoc.node_from_capa(capa.features.file.FunctionName("")) node = rdoc.node_from_capa(capa.features.file.FunctionName(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.FunctionNameFeature) assert isinstance(node.feature, frzf.FunctionNameFeature)
def test_api_node_from_capa(): def test_api_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.API("")) node = rdoc.node_from_capa(capa.features.insn.API(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.APIFeature) assert isinstance(node.feature, frzf.APIFeature)
def test_property_node_from_capa(): def test_property_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Property("")) node = rdoc.node_from_capa(capa.features.insn.Property(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.PropertyFeature) assert isinstance(node.feature, frzf.PropertyFeature)
def test_number_node_from_capa(): def test_number_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Number(0)) node = rdoc.node_from_capa(capa.features.insn.Number(0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.NumberFeature) assert isinstance(node.feature, frzf.NumberFeature)
def test_offset_node_from_capa(): def test_offset_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Offset(0)) node = rdoc.node_from_capa(capa.features.insn.Offset(0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OffsetFeature) assert isinstance(node.feature, frzf.OffsetFeature)
def test_mnemonic_node_from_capa(): def test_mnemonic_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.Mnemonic("")) node = rdoc.node_from_capa(capa.features.insn.Mnemonic(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.MnemonicFeature) assert isinstance(node.feature, frzf.MnemonicFeature)
def test_operand_number_node_from_capa(): def test_operand_number_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.OperandNumber(0, 0)) node = rdoc.node_from_capa(capa.features.insn.OperandNumber(0, 0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OperandNumberFeature) assert isinstance(node.feature, frzf.OperandNumberFeature)
def test_operand_offset_node_from_capa(): def test_operand_offset_node_from_capa():
node = rdoc.node_from_capa(capa.features.insn.OperandOffset(0, 0)) node = rdoc.node_from_capa(capa.features.insn.OperandOffset(0, 0))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.OperandOffsetFeature) assert isinstance(node.feature, frzf.OperandOffsetFeature)
def test_basic_block_node_from_capa(): def test_basic_block_node_from_capa():
node = rdoc.node_from_capa(capa.features.basicblock.BasicBlock("")) node = rdoc.node_from_capa(capa.features.basicblock.BasicBlock(""))
assert isinstance(node, rdoc.FeatureNode)
assert isinstance(node.feature, frzf.BasicBlockFeature) assert isinstance(node.feature, frzf.BasicBlockFeature)

View File

@@ -13,8 +13,10 @@ import pytest
import capa.rules import capa.rules
import capa.engine import capa.engine
import capa.features.common import capa.features.common
from capa.features.address import AbsoluteVirtualAddress
from capa.features.file import FunctionName from capa.features.file import FunctionName
from capa.features.insn import Number, Offset, Property from capa.features.insn import Number, Offset, Property
from capa.engine import Or
from capa.features.common import ( from capa.features.common import (
OS, OS,
OS_LINUX, OS_LINUX,
@@ -29,12 +31,19 @@ from capa.features.common import (
Substring, Substring,
FeatureAccess, FeatureAccess,
) )
import capa.features.address
ADDR1 = capa.features.address.AbsoluteVirtualAddress(0x401001)
ADDR2 = capa.features.address.AbsoluteVirtualAddress(0x401002)
ADDR3 = capa.features.address.AbsoluteVirtualAddress(0x401003)
ADDR4 = capa.features.address.AbsoluteVirtualAddress(0x401004)
def test_rule_ctor(): def test_rule_ctor():
r = capa.rules.Rule("test rule", capa.rules.FUNCTION_SCOPE, Number(1), {}) r = capa.rules.Rule("test rule", capa.rules.FUNCTION_SCOPE, Or(Number(1)), {})
assert r.evaluate({Number(0): {1}}) == False assert r.evaluate({Number(0): {ADDR1}}) == False
assert r.evaluate({Number(1): {1}}) == True assert r.evaluate({Number(1): {ADDR2}}) == True
def test_rule_yaml(): def test_rule_yaml():
@@ -56,10 +65,10 @@ def test_rule_yaml():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(0): {1}}) == False assert r.evaluate({Number(0): {ADDR1}}) == False
assert r.evaluate({Number(0): {1}, Number(1): {1}}) == False assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}}) == False
assert r.evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}}) == True assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}}) == True
assert r.evaluate({Number(0): {1}, Number(1): {1}, Number(2): {1}, Number(3): {1}}) == True assert r.evaluate({Number(0): {ADDR1}, Number(1): {ADDR1}, Number(2): {ADDR1}, Number(3): {ADDR1}}) == True
def test_rule_yaml_complex(): def test_rule_yaml_complex():
@@ -82,8 +91,8 @@ def test_rule_yaml_complex():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(5): {1}, Number(6): {1}, Number(7): {1}, Number(8): {1}}) == True assert r.evaluate({Number(5): {ADDR1}, Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) == True
assert r.evaluate({Number(6): {1}, Number(7): {1}, Number(8): {1}}) == False assert r.evaluate({Number(6): {ADDR1}, Number(7): {ADDR1}, Number(8): {ADDR1}}) == False
def test_rule_descriptions(): def test_rule_descriptions():
@@ -160,8 +169,8 @@ def test_rule_yaml_not():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(1): {1}}) == True assert r.evaluate({Number(1): {ADDR1}}) == True
assert r.evaluate({Number(1): {1}, Number(2): {1}}) == False assert r.evaluate({Number(1): {ADDR1}, Number(2): {ADDR1}}) == False
def test_rule_yaml_count(): def test_rule_yaml_count():
@@ -175,9 +184,9 @@ def test_rule_yaml_count():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(100): {}}) == False assert r.evaluate({Number(100): set()}) == False
assert r.evaluate({Number(100): {1}}) == True assert r.evaluate({Number(100): {ADDR1}}) == True
assert r.evaluate({Number(100): {1, 2}}) == False assert r.evaluate({Number(100): {ADDR1, ADDR2}}) == False
def test_rule_yaml_count_range(): def test_rule_yaml_count_range():
@@ -191,10 +200,10 @@ def test_rule_yaml_count_range():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(100): {}}) == False assert r.evaluate({Number(100): set()}) == False
assert r.evaluate({Number(100): {1}}) == True assert r.evaluate({Number(100): {ADDR1}}) == True
assert r.evaluate({Number(100): {1, 2}}) == True assert r.evaluate({Number(100): {ADDR1, ADDR2}}) == True
assert r.evaluate({Number(100): {1, 2, 3}}) == False assert r.evaluate({Number(100): {ADDR1, ADDR2, ADDR3}}) == False
def test_rule_yaml_count_string(): def test_rule_yaml_count_string():
@@ -208,10 +217,10 @@ def test_rule_yaml_count_string():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({String("foo"): {}}) == False assert r.evaluate({String("foo"): set()}) == False
assert r.evaluate({String("foo"): {1}}) == False assert r.evaluate({String("foo"): {ADDR1}}) == False
assert r.evaluate({String("foo"): {1, 2}}) == True assert r.evaluate({String("foo"): {ADDR1, ADDR2}}) == True
assert r.evaluate({String("foo"): {1, 2, 3}}) == False assert r.evaluate({String("foo"): {ADDR1, ADDR2, ADDR3}}) == False
def test_invalid_rule_feature(): def test_invalid_rule_feature():
@@ -481,11 +490,11 @@ def test_count_number_symbol():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Number(2): {}}) == False assert r.evaluate({Number(2): set()}) == False
assert r.evaluate({Number(2): {1}}) == True assert r.evaluate({Number(2): {ADDR1}}) == True
assert r.evaluate({Number(2): {1, 2}}) == False assert r.evaluate({Number(2): {ADDR1, ADDR2}}) == False
assert r.evaluate({Number(0x100, description="symbol name"): {1}}) == False assert r.evaluate({Number(0x100, description="symbol name"): {ADDR1}}) == False
assert r.evaluate({Number(0x100, description="symbol name"): {1, 2, 3}}) == True assert r.evaluate({Number(0x100, description="symbol name"): {ADDR1, ADDR2, ADDR3}}) == True
def test_invalid_number(): def test_invalid_number():
@@ -567,11 +576,11 @@ def test_count_offset_symbol():
""" """
) )
r = capa.rules.Rule.from_yaml(rule) r = capa.rules.Rule.from_yaml(rule)
assert r.evaluate({Offset(2): {}}) == False assert r.evaluate({Offset(2): set()}) == False
assert r.evaluate({Offset(2): {1}}) == True assert r.evaluate({Offset(2): {ADDR1}}) == True
assert r.evaluate({Offset(2): {1, 2}}) == False assert r.evaluate({Offset(2): {ADDR1, ADDR2}}) == False
assert r.evaluate({Offset(0x100, description="symbol name"): {1}}) == False assert r.evaluate({Offset(0x100, description="symbol name"): {ADDR1}}) == False
assert r.evaluate({Offset(0x100, description="symbol name"): {1, 2, 3}}) == True assert r.evaluate({Offset(0x100, description="symbol name"): {ADDR1, ADDR2, ADDR3}}) == True
def test_invalid_offset(): def test_invalid_offset():
@@ -966,10 +975,10 @@ def test_property_access():
""" """
) )
) )
assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.READ): {1}}) == True assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.READ): {ADDR1}}) == True
assert r.evaluate({Property("System.IO.FileInfo::Length"): {1}}) == False assert r.evaluate({Property("System.IO.FileInfo::Length"): {ADDR1}}) == False
assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.WRITE): {1}}) == False assert r.evaluate({Property("System.IO.FileInfo::Length", access=FeatureAccess.WRITE): {ADDR1}}) == False
def test_property_access_symbol(): def test_property_access_symbol():
@@ -986,7 +995,7 @@ def test_property_access_symbol():
) )
assert ( assert (
r.evaluate( r.evaluate(
{Property("System.IO.FileInfo::Length", access=FeatureAccess.READ, description="some property"): {1}} {Property("System.IO.FileInfo::Length", access=FeatureAccess.READ, description="some property"): {ADDR1}}
) )
== True == True
) )