diff --git a/CHANGELOG.md b/CHANGELOG.md index 36c82d80..7801546f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ - fix: Scopes.from_dict uses cls instead of self so subclasses return the correct type @williballenthin - fix: correct wrong dict key in VMRay _compute_monitor_threads assertion (used thread_id instead of process_id) @williballenthin - fix: replace assert with isinstance guard in get_callee for invalid MethodSpec tokens @williballenthin +- fix: dedent bulk-process.py main() body so explicit argv argument is used instead of silently ignored @williballenthin (SURF-90) - fix: guard statistics.quantiles/mean in compare-backends.py report() against empty duration lists @williballenthin (SURF-89) - fix: replace zipfile with pyzipper in minimize_vmray_results.py so output archive is AES-encrypted @williballenthin (SURF-88) - fix: assign yara_strings/yara_condition to empty string when Some has cmin=0 to prevent UnboundLocalError @williballenthin (SURF-87) diff --git a/scripts/bulk-process.py b/scripts/bulk-process.py index 12d64fed..936bfc87 100644 --- a/scripts/bulk-process.py +++ b/scripts/bulk-process.py @@ -160,69 +160,77 @@ def main(argv=None): if argv is None: argv = sys.argv[1:] - parser = argparse.ArgumentParser(description="detect capabilities in programs.") - capa.main.install_common_args(parser, wanted={"rules", "signatures", "format", "os", "backend"}) - parser.add_argument("input_directory", type=str, help="Path to directory of files to recursively analyze") - parser.add_argument( - "-n", "--parallelism", type=int, default=multiprocessing.cpu_count(), help="parallelism factor" - ) - parser.add_argument("--no-mp", action="store_true", help="disable subprocesses") - args = parser.parse_args(args=argv) + parser = argparse.ArgumentParser(description="detect capabilities in programs.") + capa.main.install_common_args(parser, wanted={"rules", "signatures", "format", "os", "backend"}) + parser.add_argument("input_directory", type=str, help="Path to directory of files to recursively analyze") + parser.add_argument("-n", "--parallelism", type=int, default=multiprocessing.cpu_count(), help="parallelism factor") + parser.add_argument("--no-mp", action="store_true", help="disable subprocesses") + args = parser.parse_args(args=argv) - samples = [] - for file in Path(args.input_directory).rglob("*"): - samples.append(file) + samples = [] + for file in Path(args.input_directory).rglob("*"): + samples.append(file) - cpu_count = multiprocessing.cpu_count() + cpu_count = multiprocessing.cpu_count() - def pmap(f, args, parallelism=cpu_count): - """apply the given function f to the given args using subprocesses""" - return multiprocessing.Pool(parallelism).imap(f, args) + def pmap(f, args, parallelism=cpu_count): + """apply the given function f to the given args using subprocesses""" + return multiprocessing.Pool(parallelism).imap(f, args) - def tmap(f, args, parallelism=cpu_count): - """apply the given function f to the given args using threads""" - return multiprocessing.pool.ThreadPool(parallelism).imap(f, args) + def tmap(f, args, parallelism=cpu_count): + """apply the given function f to the given args using threads""" + return multiprocessing.pool.ThreadPool(parallelism).imap(f, args) - def map(f, args, parallelism=None): - """apply the given function f to the given args in the current thread""" - for arg in args: - yield f(arg) + def map(f, args, parallelism=None): + """apply the given function f to the given args in the current thread""" + for arg in args: + yield f(arg) - if args.no_mp: - if args.parallelism == 1: - logger.debug("using current thread mapper") - mapper = map - else: - logger.debug("using threading mapper") - mapper = tmap + if args.no_mp: + if args.parallelism == 1: + logger.debug("using current thread mapper") + mapper = map else: - logger.debug("using process mapper") - mapper = pmap + logger.debug("using threading mapper") + mapper = tmap + else: + logger.debug("using process mapper") + mapper = pmap - rules = args.rules - if rules == [capa.main.RULES_PATH_DEFAULT_STRING]: - rules = None + rules = args.rules + if rules == [capa.main.RULES_PATH_DEFAULT_STRING]: + rules = None - results = {} - for result in mapper( - get_capa_results, - [(rules, args.signatures, args.format, args.backend, args.os, str(sample)) for sample in samples], - parallelism=args.parallelism, - ): - if result["status"] == "error": - logger.warning(result["error"]) - elif result["status"] == "ok": - doc = rd.ResultDocument.model_validate(result["ok"]).model_dump_json(exclude_none=True) - results[result["path"]] = json.loads(doc) + results = {} + for result in mapper( + get_capa_results, + [ + ( + rules, + args.signatures, + args.format, + args.backend, + args.os, + str(sample), + ) + for sample in samples + ], + parallelism=args.parallelism, + ): + if result["status"] == "error": + logger.warning(result["error"]) + elif result["status"] == "ok": + doc = rd.ResultDocument.model_validate(result["ok"]).model_dump_json(exclude_none=True) + results[result["path"]] = json.loads(doc) - else: - raise ValueError(f"unexpected status: {result['status']}") + else: + raise ValueError(f"unexpected status: {result['status']}") - print(json.dumps(results)) + print(json.dumps(results)) - logger.info("done.") + logger.info("done.") - return 0 + return 0 if __name__ == "__main__": diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 0e679a7a..9a54f472 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -119,6 +119,26 @@ def test_bulk_process(tmp_path): assert p.returncode == 0 +def test_bulk_process_explicit_argv(tmp_path): + import importlib.util + + t = tmp_path / "test" + t.mkdir() + + source_file = Path(__file__).resolve().parent / "data" / "ping_täst.exe_" + dest_file = t / "test.exe_" + dest_file.write_bytes(source_file.read_bytes()) + + spec = importlib.util.spec_from_file_location("bulk_process", get_script_path("bulk-process.py")) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) # type: ignore[union-attr] + + result = module.main(argv=[str(t.parent), "--no-mp", "--parallelism", "1"]) + assert result == 0 + + def run_program(script_path, args): args = [sys.executable] + [script_path] + args logger.debug("running: %r", args)