fix: imports and add tests

This commit is contained in:
Moritz Raabe
2022-04-06 17:29:12 +02:00
parent b5be876e61
commit 97e76a88e3
5 changed files with 37 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ import logging
from typing import Tuple, Iterator
import dnfile
import pefile
from capa.features.common import OS, OS_ANY, ARCH_ANY, ARCH_I386, ARCH_AMD64, FORMAT_DOTNET, Arch, Format, Feature
from capa.features.extractors.base_extractor import FeatureExtractor
@@ -18,11 +19,11 @@ def extract_file_os(**kwargs):
def extract_file_arch(pe, **kwargs):
# Yield one (Arch, 0x0) pair for `pe`, selected from the CLR header flags
# (pe.net.Flags) and the optional-header magic (pe.PE_TYPE).
# NOTE(review): this hunk appears to show both the removed and the added side
# of the change without +/- markers, so each older comment/condition line sits
# directly above its replacement -- confirm against the commit.
# TODO differences for versions < 4.5?
# via https://stackoverflow.com/a/23614024/10548020
if pe.net.Flags.CLR_32BITREQUIRED and pe.net.Flags.CLR_PREFER_32BIT:
# to distinguish in more detail, see https://stackoverflow.com/a/23614024/10548020
# .NET 4.5 added option: any CPU, 32-bit preferred
if pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE:
# CLR_32BITREQUIRED set and PE_TYPE equal to OPTIONAL_HEADER_MAGIC_PE -> ARCH_I386
yield Arch(ARCH_I386), 0x0
elif not pe.net.Flags.CLR_32BITREQUIRED and not pe.net.Flags.CLR_PREFER_32BIT:
elif not pe.net.Flags.CLR_32BITREQUIRED and pe.PE_TYPE == pefile.OPTIONAL_HEADER_MAGIC_PE_PLUS:
# CLR_32BITREQUIRED clear and PE_TYPE equal to OPTIONAL_HEADER_MAGIC_PE_PLUS -> ARCH_AMD64
yield Arch(ARCH_AMD64), 0x0
else:
# neither condition above held -> ARCH_ANY
yield Arch(ARCH_ANY), 0x0
@@ -63,6 +64,9 @@ class DnfileFeatureExtractor(FeatureExtractor):
self.pe: dnfile.dnPE = dnfile.dnPE(path)
def get_base_address(self) -> int:
# Always reports 0x0 as the base address; no state on `self` is consulted.
return 0x0
def get_entry_point(self) -> int:
# Return the EntryPointTokenOrRva field read from the .NET header structure
# of the parsed file (self.pe.net.struct), unmodified.
return self.pe.net.struct.EntryPointTokenOrRva
def extract_global_features(self):
@@ -78,7 +82,7 @@ class DnfileFeatureExtractor(FeatureExtractor):
return self.pe.net.struct.MajorRuntimeVersion, self.pe.net.struct.MinorRuntimeVersion
def get_meta_version_string(self) -> str:
# Return the metadata Version field (self.pe.net.metadata.struct.Version)
# decoded from bytes as UTF-8.
# NOTE(review): the two return lines below appear to be the removed and the
# added side of the diff (the +/- markers are not visible). The second one
# strips trailing NUL bytes before decoding -- confirm against the commit.
return self.pe.net.metadata.struct.Version.decode("utf-8")
return self.pe.net.metadata.struct.Version.rstrip(b"\x00").decode("utf-8")
def get_functions(self):
# Not supported by this extractor: unconditionally raises NotImplementedError,
# whose message states that only file features can be extracted.
raise NotImplementedError("DnfileFeatureExtractor can only be used to extract file features")

View File

@@ -59,6 +59,7 @@ import capa.features.insn
import capa.features.common
import capa.features.basicblock
import capa.features.extractors.base_extractor
from capa.features.common import Feature
logger = logging.getLogger(__name__)

View File

@@ -63,7 +63,7 @@ from capa.features.common import (
FORMAT_DOTNET,
FORMAT_FREEZE,
)
from capa.features.extractors.base_extractor import FunctionHandle, FeatureExtractor
from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle, FeatureExtractor
RULES_PATH_DEFAULT_STRING = "(embedded rules)"
SIGNATURES_PATH_DEFAULT_STRING = "(embedded signatures)"

View File

@@ -24,6 +24,7 @@ import capa.features.common
import capa.features.basicblock
from capa.features.common import (
OS,
OS_ANY,
OS_LINUX,
ARCH_I386,
FORMAT_PE,
@@ -32,6 +33,7 @@ from capa.features.common import (
OS_WINDOWS,
BITNESS_X32,
BITNESS_X64,
FORMAT_DOTNET,
Arch,
Format,
)
@@ -134,6 +136,12 @@ def get_pefile_extractor(path):
return capa.features.extractors.pefile.PefileFeatureExtractor(path)
def get_dnfile_extractor(path):
# Build a DnfileFeatureExtractor for the file at `path`.
# The extractor module is imported inside the function, so the import only
# happens when this helper is called.
import capa.features.extractors.dnfile_
return capa.features.extractors.dnfile_.DnfileFeatureExtractor(path)
def extract_global_features(extractor):
features = collections.defaultdict(set)
for feature, va in extractor.extract_global_features():
@@ -591,6 +599,8 @@ FEATURE_PRESENCE_TESTS_DOTNET = sorted(
[
("b9f5b", "file", Arch(ARCH_I386), True),
("b9f5b", "file", Arch(ARCH_AMD64), False),
("b9f5b", "file", OS(OS_ANY), True),
("b9f5b", "file", Format(FORMAT_DOTNET), True),
],
# order tests by (file, item)
# so that our LRU cache is most effective.
@@ -713,4 +723,4 @@ def pingtaest_extractor():
@pytest.fixture
def b9f5b_extractor():
# pytest fixture that returns an extractor for the sample named "b9f5b",
# with the sample path resolved via get_data_path_by_name.
# NOTE(review): the two return lines below appear to be the removed and the
# added side of the diff (the +/- markers are not visible). The second one
# builds the extractor with get_dnfile_extractor instead of get_extractor --
# confirm against the commit.
return get_extractor(get_data_path_by_name("b9f5b"))
return get_dnfile_extractor(get_data_path_by_name("b9f5b"))

View File

@@ -22,4 +22,18 @@ import capa.features.file
indirect=["sample", "scope"],
)
def test_dnfile_features(sample, scope, feature, expected):
fixtures.do_test_feature_presence(fixtures.get_pefile_extractor, sample, scope, feature, expected)
fixtures.do_test_feature_presence(fixtures.get_dnfile_extractor, sample, scope, feature, expected)
# Parametrized test: each case names a zero-argument method of the extractor
# supplied by the b9f5b_extractor fixture, paired with the value that method
# is expected to return.
@parametrize(
"function,expected",
[
("is_dotnet_file", True),
("get_entry_point", 0x6000007),
("get_runtime_version", (2, 5)),
("get_meta_version_string", "v2.0.50727"),
],
)
def test_dnfile_extractor(b9f5b_extractor, function, expected):
# look the method up by name so a single test body covers every case
func = getattr(b9f5b_extractor, function)
assert func() == expected