tests: add snapshot tests for feature extraction (#3069)

Introduces data-driven snapshot tests that regenerate capa freeze files
for a curated set of samples in the tests/data submodule and compare the
bytes against committed fixtures under tests/fixtures/freezes/. Any
change that perturbs feature extraction surfaces as a test failure with
a feature-count delta and a truncated unified diff.
This commit is contained in:
Willi Ballenthin
2026-06-09 23:28:49 +02:00
committed by GitHub
parent 58bfa7607e
commit ccf3a87e83
7 changed files with 357 additions and 19 deletions
+2
View File
@@ -3,6 +3,7 @@
## master (unreleased)
### New Features
- freeze: add `--reproducible` flag that zeros dynamic header metadata
### Breaking Changes
@@ -131,6 +132,7 @@
- ci: use explicit and per job permissions @mike-hunhoff #3002
- replace black/isort/flake8 with ruff @mike-hunhoff #2992
- ci: update GitHub Actions to support Node.js 24 (deprecate Node.js 20) @mr-tz #2984
- tests: add snapshot tests for feature extraction @williballenthin #3069
### Raw diffs
- [capa v9.4.0...master](https://github.com/mandiant/capa/compare/v9.4.0...master)
+1 -1
View File
@@ -98,7 +98,7 @@ def extract_file_namespace_features(pe: dnfile.dnPE, **kwargs) -> Iterator[tuple
# namespaces may be empty, discard
namespaces.discard("")
for namespace in namespaces:
for namespace in sorted(namespaces):
# namespace do not have an associated token, so we yield 0x0
yield Namespace(namespace), NO_ADDRESS
+77 -16
View File
@@ -92,10 +92,7 @@ class Address(HashableModel):
return cls(type=AddressType.THREAD, value=(a.process.ppid, a.process.pid, a.tid))
elif isinstance(a, capa.features.address.DynamicCallAddress):
return cls(
type=AddressType.CALL,
value=(a.thread.process.ppid, a.thread.process.pid, a.thread.tid, a.id),
)
return cls(type=AddressType.CALL, value=(a.thread.process.ppid, a.thread.process.pid, a.thread.tid, a.id))
elif a == capa.features.address.NO_ADDRESS or isinstance(a, capa.features.address._NoAddress):
return cls(type=AddressType.NO_ADDRESS, value=None)
@@ -346,9 +343,14 @@ class Freeze(BaseModel):
model_config = ConfigDict(populate_by_name=True)
def dumps_static(extractor: StaticFeatureExtractor) -> str:
def dumps_static(extractor: StaticFeatureExtractor, reproducible: bool = False) -> str:
"""
serialize the given extractor to a string
When `reproducible` is true, the freeze's dynamic header metadata (e.g. the
embedded capa version) is zeroed out so that output is identical across
capa versions for a given extractor. This is used by the feature snapshot
tests to keep fixtures stable across version bumps.
"""
global_features: list[GlobalFeature] = []
for feature, _ in extractor.extract_global_features():
@@ -357,6 +359,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
feature=feature_from_capa(feature),
)
)
global_features.sort(key=lambda gf: gf.feature.model_dump_json())
file_features: list[FileFeature] = []
for feature, address in extractor.extract_file_features():
@@ -366,6 +369,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
address=Address.from_capa(address),
)
)
file_features.sort(key=lambda ff: (ff.address, ff.feature.model_dump_json()))
function_features: list[FunctionFeatures] = []
for f in extractor.get_functions():
@@ -378,6 +382,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
)
for feature, addr in extractor.extract_function_features(f)
]
ffeatures.sort(key=lambda ff: (ff.address, ff.feature.model_dump_json()))
basic_blocks = []
for bb in extractor.get_basic_blocks(f):
@@ -390,6 +395,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
)
for feature, addr in extractor.extract_basic_block_features(f, bb)
]
bbfeatures.sort(key=lambda bf: (bf.address, bf.feature.model_dump_json()))
instructions = []
for insn in extractor.get_instructions(f, bb):
@@ -402,6 +408,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
)
for feature, addr in extractor.extract_insn_features(f, bb, insn)
]
ifeatures.sort(key=lambda i: (i.address, i.feature.model_dump_json()))
instructions.append(
InstructionFeatures(
@@ -410,6 +417,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
)
)
instructions.sort(key=lambda i: i.address)
basic_blocks.append(
BasicBlockFeatures(
address=bbaddr,
@@ -418,6 +426,7 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
)
)
basic_blocks.sort(key=lambda bb: bb.address)
function_features.append(
FunctionFeatures(
address=faddr,
@@ -426,18 +435,21 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
)
)
function_features.sort(key=lambda ff: ff.address)
features = StaticFeatures(
global_=global_features, # type: ignore[call-arg] # pydantic alias "global" not recognized by type checkers
file=tuple(file_features),
functions=tuple(function_features),
)
extractor_version = "" if reproducible else capa.version.__version__
freeze = Freeze(
version=CURRENT_VERSION,
base_address=Address.from_capa(extractor.get_base_address()), # type: ignore[call-arg] # pydantic alias "base address" not recognized by type checkers
sample_hashes=extractor.get_sample_hashes(),
flavor="static",
extractor=Extractor(name=extractor.__class__.__name__),
extractor=Extractor(name=extractor.__class__.__name__, version=extractor_version),
features=features,
)
# type checkers are unable to recognise `base_address` as an argument due to alias
@@ -445,9 +457,11 @@ def dumps_static(extractor: StaticFeatureExtractor) -> str:
return freeze.model_dump_json()
def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
def dumps_dynamic(extractor: DynamicFeatureExtractor, reproducible: bool = False) -> str:
"""
serialize the given extractor to a string
See `dumps_static` for `reproducible`.
"""
global_features: list[GlobalFeature] = []
for feature, _ in extractor.extract_global_features():
@@ -456,6 +470,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
feature=feature_from_capa(feature),
)
)
global_features.sort(key=lambda gf: gf.feature.model_dump_json())
file_features: list[FileFeature] = []
for feature, address in extractor.extract_file_features():
@@ -465,6 +480,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
address=Address.from_capa(address),
)
)
file_features.sort(key=lambda ff: (ff.address, ff.feature.model_dump_json()))
process_features: list[ProcessFeatures] = []
for p in extractor.get_processes():
@@ -478,6 +494,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
)
for feature, addr in extractor.extract_process_features(p)
]
pfeatures.sort(key=lambda pf: (pf.address, pf.feature.model_dump_json()))
threads = []
for t in extractor.get_threads(p):
@@ -490,6 +507,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
)
for feature, addr in extractor.extract_thread_features(p, t)
]
tfeatures.sort(key=lambda tf: (tf.address, tf.feature.model_dump_json()))
calls = []
for call in extractor.get_calls(p, t):
@@ -503,6 +521,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
)
for feature, addr in extractor.extract_call_features(p, t, call)
]
cfeatures.sort(key=lambda cf: (cf.address, cf.feature.model_dump_json()))
calls.append(
CallFeatures(
@@ -512,6 +531,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
)
)
calls.sort(key=lambda c: c.address)
threads.append(
ThreadFeatures(
address=taddr,
@@ -520,6 +540,7 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
)
)
threads.sort(key=lambda t: t.address)
process_features.append(
ProcessFeatures(
address=paddr,
@@ -529,6 +550,8 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
)
)
process_features.sort(key=lambda pf: pf.address)
features = DynamicFeatures(
global_=global_features, # type: ignore[call-arg] # pydantic alias "global" not recognized by type checkers
file=tuple(file_features),
@@ -539,12 +562,13 @@ def dumps_dynamic(extractor: DynamicFeatureExtractor) -> str:
get_base_addr = getattr(extractor, "get_base_address", None)
base_addr = get_base_addr() if get_base_addr else capa.features.address.NO_ADDRESS
extractor_version = "" if reproducible else capa.version.__version__
freeze = Freeze(
version=CURRENT_VERSION,
base_address=Address.from_capa(base_addr), # type: ignore[call-arg] # pydantic alias "base address" not recognized by type checkers
sample_hashes=extractor.get_sample_hashes(),
flavor="dynamic",
extractor=Extractor(name=extractor.__class__.__name__),
extractor=Extractor(name=extractor.__class__.__name__, version=extractor_version),
features=features,
)
# type checkers are unable to recognise `base_address` as an argument due to alias
@@ -627,28 +651,28 @@ def loads_dynamic(s: str) -> DynamicFeatureExtractor:
MAGIC = "capa0000".encode("ascii")
def dumps(extractor: FeatureExtractor) -> str:
def dumps(extractor: FeatureExtractor, reproducible: bool = False) -> str:
"""serialize the given extractor to a string."""
if isinstance(extractor, StaticFeatureExtractor):
doc = dumps_static(extractor)
doc = dumps_static(extractor, reproducible=reproducible)
elif isinstance(extractor, DynamicFeatureExtractor):
doc = dumps_dynamic(extractor)
doc = dumps_dynamic(extractor, reproducible=reproducible)
else:
raise ValueError("Invalid feature extractor")
return doc
def dump(extractor: FeatureExtractor) -> bytes:
def dump(extractor: FeatureExtractor, reproducible: bool = False) -> bytes:
"""serialize the given extractor to a byte array."""
return MAGIC + zlib.compress(dumps(extractor).encode("utf-8"))
return MAGIC + zlib.compress(dumps(extractor, reproducible=reproducible).encode("utf-8"))
def is_freeze(buf: bytes) -> bool:
return buf[: len(MAGIC)] == MAGIC
def loads(s: str):
def loads(s: str) -> FeatureExtractor:
doc = json.loads(s)
if doc["version"] != CURRENT_VERSION:
@@ -662,7 +686,7 @@ def loads(s: str):
raise ValueError(f"unsupported freeze format flavor: {doc['flavor']}")
def load(buf: bytes):
def load(buf: bytes) -> FeatureExtractor:
"""deserialize a set of features (as a NullFeatureExtractor) from a byte array."""
if not is_freeze(buf):
raise ValueError("missing magic header")
@@ -685,6 +709,11 @@ def main(argv=None):
parser = argparse.ArgumentParser(description="save capa features to a file")
capa.main.install_common_args(parser, {"input_file", "format", "backend", "os", "signatures"})
parser.add_argument("output", type=str, help="Path to output file")
parser.add_argument(
"--reproducible",
action="store_true",
help="zero out dynamic header metadata (e.g. capa version) so output is stable across capa versions",
)
args = parser.parse_args(args=argv)
try:
@@ -696,11 +725,43 @@ def main(argv=None):
except capa.main.ShouldExitError as e:
return e.status_code
Path(args.output).write_bytes(dump(extractor))
output_path = Path(args.output)
output_path.write_bytes(dump(extractor, reproducible=args.reproducible))
# Log a manifest entry for the feature snapshot tests at INFO level. This
# makes it easy to copy/paste into
# `tests/fixtures/snapshots/features/manifest.json` when adding a new
# fixture or refreshing an existing one.
entry: dict[str, str] = {
"name": output_path.stem,
"sample": str(args.input_file),
"freeze": output_path.name,
}
if args.format and args.format != "auto":
entry["format"] = args.format
if args.backend and args.backend != "auto":
entry["backend"] = args.backend
if args.os and args.os != "auto":
entry["os"] = args.os
commit = _git_head_commit()
if commit:
entry["generated_at_commit"] = commit
logger.info("manifest entry: %s", json.dumps(entry))
return 0
def _git_head_commit() -> str:
"""Return the HEAD commit, or empty string if this isn't a git checkout."""
import subprocess
try:
out = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL)
except (subprocess.CalledProcessError, FileNotFoundError, OSError):
return ""
return out.decode("ascii", errors="replace").strip()
if __name__ == "__main__":
import sys
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from capa.features.freeze import main
sys.exit(main())
+256
View File
@@ -0,0 +1,256 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data-driven feature snapshot tests.
For every entry in `tests/fixtures/snapshots/features/manifest.json`, this
module regenerates a capa freeze from the corresponding sample via
`capa.features.freeze.main --reproducible`, compares it byte-for-byte
against the committed `.frz` file, and on mismatch renders a unified diff
of the freeze contents so a reviewer can see which features appeared,
disappeared, or moved.
A failing test means capa now extracts different features from the same
sample than it used to. That can be intentional (you changed an extractor)
or accidental (an unrelated change perturbed extraction); see the failure
message for how to update the fixture or investigate.
Refreshing a fixture after an intentional change::
python -m capa.features.freeze --reproducible \\
tests/data/<sample> tests/fixtures/snapshots/features/<name>.frz
The manifest is edited by hand when samples are added or removed.
"""
from __future__ import annotations
import json
import zlib
import difflib
import tempfile
from typing import Any, Optional
from pathlib import Path
import pytest
from pydantic import BaseModel, ConfigDict
import capa.features.freeze
TESTS_DIR = Path(__file__).resolve().parent
TESTS_DATA_DIR = TESTS_DIR / "data"
FEATURE_SNAPSHOTS_DIR = TESTS_DATA_DIR / "fixtures" / "snapshots" / "features"
MANIFEST_PATH = FEATURE_SNAPSHOTS_DIR / "manifest.json"
class FeatureSnapshot(BaseModel):
"""One entry in the feature snapshot manifest."""
model_config = ConfigDict(frozen=True)
name: str
sample: str
freeze: str
explanation: str = ""
# Git commit at which this fixture was last regenerated. Purely informational:
# on test failure we surface it so a reviewer can run `git log <commit>..HEAD`
# to see what's changed since. Not validated — humans keep it accurate.
generated_at_commit: Optional[str] = None
format: Optional[str] = None
backend: Optional[str] = None
os: Optional[str] = None
@property
def sample_path(self) -> Path:
return TESTS_DATA_DIR / self.sample
@property
def freeze_path(self) -> Path:
return FEATURE_SNAPSHOTS_DIR / self.freeze
class Manifest(BaseModel):
version: int = 1
description: str = ""
snapshots: list[FeatureSnapshot]
@classmethod
def from_file(cls, path: Path = MANIFEST_PATH) -> Manifest:
return cls.model_validate_json(path.read_text(encoding="utf-8"))
_SNAPSHOTS = Manifest.from_file().snapshots
def _ids(snapshots: list[FeatureSnapshot]) -> list[str]:
return [s.name for s in snapshots]
def _regenerate(snapshot: FeatureSnapshot) -> bytes:
"""Run the freeze CLI against the sample and return the produced bytes."""
import logging
root = logging.getLogger()
handlers_before = list(root.handlers)
with tempfile.TemporaryDirectory() as tmp:
out_path = Path(tmp) / "out.frz"
argv = [str(snapshot.sample_path), str(out_path), "--reproducible"]
if snapshot.format is not None:
argv += ["--format", snapshot.format]
if snapshot.backend is not None:
argv += ["--backend", snapshot.backend]
if snapshot.os is not None:
argv += ["--os", snapshot.os]
rc = capa.features.freeze.main(argv)
# capa.main.handle_common_args() unconditionally appends a RichHandler
# to the root logger on every call. Since we call freeze.main() once per
# snapshot, handlers accumulate and duplicate every log line. Remove
# whatever was added so the next iteration starts clean.
for h in root.handlers[:]:
if h not in handlers_before:
root.removeHandler(h)
if rc != 0:
raise RuntimeError(f"capa.features.freeze.main exited with status {rc}")
return out_path.read_bytes()
def _doc_to_lines(doc: dict[str, Any]) -> list[str]:
"""
Render a freeze JSON document to a list of lines suitable for unified-diffing.
We pretty-print with sorted keys so that field reordering (which is
meaningful for features) is preserved while key ordering within objects is
normalized.
"""
return json.dumps(doc, indent=2, sort_keys=True).splitlines(keepends=True)
def _load_freeze_doc(buf: bytes) -> dict[str, Any]:
"""deserialize bytes to capa.features.freeze.Freeze, as JSON-like object.
capa.features.freeze.loads() deserializes into a FeatureExtractor, not Freeze (or JSON, which we need for diffing).
"""
magic = capa.features.freeze.MAGIC
assert buf[: len(magic)] == magic, "missing freeze magic header"
return json.loads(zlib.decompress(buf[len(magic) :]).decode("utf-8"))
def _format_mismatch(snapshot: FeatureSnapshot, expected: bytes, actual: bytes) -> str:
"""Build a failure message describing how the freezes differ."""
lines = [
f"feature snapshot drift for {snapshot.name!r}:",
f" sample: {snapshot.sample}",
f" expected freeze: {snapshot.freeze_path}",
" actual freeze: <regenerated>",
f" expected size: {len(expected):,} bytes",
f" actual size: {len(actual):,} bytes",
]
if snapshot.generated_at_commit:
lines.append(f" last regenerated at: {snapshot.generated_at_commit}")
expected_doc = _load_freeze_doc(expected)
actual_doc = _load_freeze_doc(actual)
expected_lines = _doc_to_lines(expected_doc)
actual_lines = _doc_to_lines(actual_doc)
# difflib.unified_diff uses SequenceMatcher which is O(n^2) for dissimilar
# sequences. Large freeze documents (e.g. mimikatz) expand to millions of
# JSON lines, making a naive diff take hours. Skip it when the input is too
# large — the regeneration command below is the intended way to inspect.
MAX_DIFFABLE_LINES = 100_000
MAX_DIFF_LINES = 200
total_lines = len(expected_lines) + len(actual_lines)
lines.append("")
if total_lines > MAX_DIFFABLE_LINES:
lines.append(
f"diff skipped: documents too large ({len(expected_lines):,} + {len(actual_lines):,} lines)."
" Regenerate the fixture locally to inspect."
)
else:
diff = list(
difflib.unified_diff(
expected_lines,
actual_lines,
fromfile=f"expected/{snapshot.freeze}",
tofile=f"actual/{snapshot.freeze}",
n=2,
)
)
if len(diff) > MAX_DIFF_LINES:
lines.append(f"unified diff ({len(diff)} lines, truncated to {MAX_DIFF_LINES}):")
diff = diff[:MAX_DIFF_LINES]
else:
lines.append(f"unified diff ({len(diff)} lines):")
lines.extend(line.rstrip("\n") for line in diff)
lines.append("")
lines.append("how and when to update this snapshot:")
lines.append(" If this change to feature extraction is INTENTIONAL (you edited an extractor):")
lines.append(" 1. regenerate the fixture:")
lines.append(
f" python -m capa.features.freeze --reproducible \\\n"
f" {snapshot.sample_path} {snapshot.freeze_path}"
)
lines.append(
" 2. update `generated_at_commit` in manifest.json to HEAD (the freeze CLI emits a suggested entry at INFO)."
)
lines.append(" If it is ACCIDENTAL (extraction shifted as a side effect of an unrelated change),")
lines.append(" do NOT update the fixture; fix the root cause instead.")
if snapshot.generated_at_commit:
lines.append(
f" To see what's changed since this fixture was last regenerated:\n"
f" git log {snapshot.generated_at_commit}..HEAD -- capa/"
)
return "\n".join(lines)
_BACKEND_AVAILABLE: dict[str, bool] = {}
def _is_backend_available(backend: str) -> bool:
if backend not in _BACKEND_AVAILABLE:
if backend == "ida":
try:
import idapro # noqa: F401
_BACKEND_AVAILABLE[backend] = True
except ImportError:
_BACKEND_AVAILABLE[backend] = False
else:
_BACKEND_AVAILABLE[backend] = True
return _BACKEND_AVAILABLE[backend]
@pytest.mark.parametrize("snapshot", _SNAPSHOTS, ids=_ids(_SNAPSHOTS))
def test_feature_snapshot(snapshot: FeatureSnapshot):
"""
Regenerate the freeze for `snapshot.sample` and assert it matches
`snapshot.freeze` byte-for-byte.
"""
if snapshot.backend and not _is_backend_available(snapshot.backend):
pytest.skip(f"{snapshot.backend} backend not available")
expected = snapshot.freeze_path.read_bytes()
actual = _regenerate(snapshot)
if actual == expected:
return
pytest.fail(_format_mismatch(snapshot, expected, actual))
+1 -1
View File
@@ -122,7 +122,7 @@ def test_null_feature_extractor():
def compare_extractors(a: DynamicFeatureExtractor, b: DynamicFeatureExtractor):
assert list(a.extract_file_features()) == list(b.extract_file_features())
assert sorted(set(a.extract_file_features())) == sorted(set(b.extract_file_features()))
assert addresses(a.get_processes()) == addresses(b.get_processes())
for p in a.get_processes():
+1 -1
View File
@@ -129,7 +129,7 @@ def test_null_feature_extractor():
def compare_extractors(a, b):
assert list(a.extract_file_features()) == list(b.extract_file_features())
assert sorted(set(a.extract_file_features())) == sorted(set(b.extract_file_features()))
assert addresses(a.get_functions()) == addresses(b.get_functions())
for f in a.get_functions():