fix: dedent bulk-process.py main() body so explicit argv is used

The entire main() body was indented inside `if argv is None:`, causing
main() to silently return None when called with an explicit argv list.

Closes SURF-90.
This commit is contained in:
Willi Ballenthin
2026-04-22 22:23:57 +03:00
committed by Willi Ballenthin
parent a938c87fa4
commit 7d8714098c
3 changed files with 79 additions and 50 deletions
+1
View File
@@ -49,6 +49,7 @@
- fix: Scopes.from_dict uses cls instead of self so subclasses return the correct type @williballenthin
- fix: correct wrong dict key in VMRay _compute_monitor_threads assertion (used thread_id instead of process_id) @williballenthin
- fix: replace assert with isinstance guard in get_callee for invalid MethodSpec tokens @williballenthin
- fix: dedent bulk-process.py main() body so explicit argv argument is used instead of silently ignored @williballenthin (SURF-90)
- fix: guard statistics.quantiles/mean in compare-backends.py report() against empty duration lists @williballenthin (SURF-89)
- fix: replace zipfile with pyzipper in minimize_vmray_results.py so output archive is AES-encrypted @williballenthin (SURF-88)
- fix: assign yara_strings/yara_condition to empty string when Some has cmin=0 to prevent UnboundLocalError @williballenthin (SURF-87)
+58 -50
View File
@@ -160,69 +160,77 @@ def main(argv=None):
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(description="detect capabilities in programs.")
capa.main.install_common_args(parser, wanted={"rules", "signatures", "format", "os", "backend"})
parser.add_argument("input_directory", type=str, help="Path to directory of files to recursively analyze")
parser.add_argument(
"-n", "--parallelism", type=int, default=multiprocessing.cpu_count(), help="parallelism factor"
)
parser.add_argument("--no-mp", action="store_true", help="disable subprocesses")
args = parser.parse_args(args=argv)
parser = argparse.ArgumentParser(description="detect capabilities in programs.")
capa.main.install_common_args(parser, wanted={"rules", "signatures", "format", "os", "backend"})
parser.add_argument("input_directory", type=str, help="Path to directory of files to recursively analyze")
parser.add_argument("-n", "--parallelism", type=int, default=multiprocessing.cpu_count(), help="parallelism factor")
parser.add_argument("--no-mp", action="store_true", help="disable subprocesses")
args = parser.parse_args(args=argv)
samples = []
for file in Path(args.input_directory).rglob("*"):
samples.append(file)
samples = []
for file in Path(args.input_directory).rglob("*"):
samples.append(file)
cpu_count = multiprocessing.cpu_count()
cpu_count = multiprocessing.cpu_count()
def pmap(f, args, parallelism=cpu_count):
"""apply the given function f to the given args using subprocesses"""
return multiprocessing.Pool(parallelism).imap(f, args)
def pmap(f, args, parallelism=cpu_count):
"""apply the given function f to the given args using subprocesses"""
return multiprocessing.Pool(parallelism).imap(f, args)
def tmap(f, args, parallelism=cpu_count):
"""apply the given function f to the given args using threads"""
return multiprocessing.pool.ThreadPool(parallelism).imap(f, args)
def tmap(f, args, parallelism=cpu_count):
"""apply the given function f to the given args using threads"""
return multiprocessing.pool.ThreadPool(parallelism).imap(f, args)
def map(f, args, parallelism=None):
"""apply the given function f to the given args in the current thread"""
for arg in args:
yield f(arg)
def map(f, args, parallelism=None):
"""apply the given function f to the given args in the current thread"""
for arg in args:
yield f(arg)
if args.no_mp:
if args.parallelism == 1:
logger.debug("using current thread mapper")
mapper = map
else:
logger.debug("using threading mapper")
mapper = tmap
if args.no_mp:
if args.parallelism == 1:
logger.debug("using current thread mapper")
mapper = map
else:
logger.debug("using process mapper")
mapper = pmap
logger.debug("using threading mapper")
mapper = tmap
else:
logger.debug("using process mapper")
mapper = pmap
rules = args.rules
if rules == [capa.main.RULES_PATH_DEFAULT_STRING]:
rules = None
rules = args.rules
if rules == [capa.main.RULES_PATH_DEFAULT_STRING]:
rules = None
results = {}
for result in mapper(
get_capa_results,
[(rules, args.signatures, args.format, args.backend, args.os, str(sample)) for sample in samples],
parallelism=args.parallelism,
):
if result["status"] == "error":
logger.warning(result["error"])
elif result["status"] == "ok":
doc = rd.ResultDocument.model_validate(result["ok"]).model_dump_json(exclude_none=True)
results[result["path"]] = json.loads(doc)
results = {}
for result in mapper(
get_capa_results,
[
(
rules,
args.signatures,
args.format,
args.backend,
args.os,
str(sample),
)
for sample in samples
],
parallelism=args.parallelism,
):
if result["status"] == "error":
logger.warning(result["error"])
elif result["status"] == "ok":
doc = rd.ResultDocument.model_validate(result["ok"]).model_dump_json(exclude_none=True)
results[result["path"]] = json.loads(doc)
else:
raise ValueError(f"unexpected status: {result['status']}")
else:
raise ValueError(f"unexpected status: {result['status']}")
print(json.dumps(results))
print(json.dumps(results))
logger.info("done.")
logger.info("done.")
return 0
return 0
if __name__ == "__main__":
+20
View File
@@ -119,6 +119,26 @@ def test_bulk_process(tmp_path):
assert p.returncode == 0
def test_bulk_process_explicit_argv(tmp_path):
import importlib.util
t = tmp_path / "test"
t.mkdir()
source_file = Path(__file__).resolve().parent / "data" / "ping_täst.exe_"
dest_file = t / "test.exe_"
dest_file.write_bytes(source_file.read_bytes())
spec = importlib.util.spec_from_file_location("bulk_process", get_script_path("bulk-process.py"))
assert spec is not None
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module) # type: ignore[union-attr]
result = module.main(argv=[str(t.parent), "--no-mp", "--parallelism", "1"])
assert result == 0
def run_program(script_path, args):
args = [sys.executable] + [script_path] + args
logger.debug("running: %r", args)