diff --git a/scripts/codecut.py b/scripts/codecut.py
index 2b12fe38..67ffa959 100644
--- a/scripts/codecut.py
+++ b/scripts/codecut.py
@@ -1,11 +1,14 @@
+import os
 import sys
 import json
 import logging
 import sqlite3
 import argparse
+import subprocess
 from typing import Iterator, Optional
 from pathlib import Path
 from dataclasses import dataclass
+from multiprocessing import Pool
 
 import pefile
 import lancelot
@@ -52,6 +55,9 @@ def compute_thunks(be2: BinExport2, idx: BinExport2Index) -> dict[int, int]:
             # that we can't actually resolve here.
             break
 
+    if len(thunk_callees) != 1:
+        for thunk_callee in thunk_callees:
+            logger.warning("thunk @ %s: candidate callee: %s", hex(addr), hex(be2.call_graph.vertex[thunk_callee].address))
     assert len(thunk_callees) == 1, f"thunk @ {hex(addr)} failed"
 
     thunked_vertex_idx: int = thunk_callees[0]
@@ -209,6 +215,14 @@ class Assemblage:
         path = self.get_path_by_binary_id(binary_id)
         return pefile.PE(data=path.read_bytes(), fast_load=True)
 
+    def get_binary_ids(self) -> Iterator[int]:
+        with self.conn:
+            cur = self.conn.execute("SELECT DISTINCT binary_id FROM assemblage ORDER BY binary_id ASC;")
+            row = cur.fetchone()
+            while row:
+                yield row[0]
+                row = cur.fetchone()
+
 
 def generate_main(args: argparse.Namespace) -> int:
     if not args.assemblage_database.is_file():
@@ -247,6 +261,14 @@ def generate_main(args: argparse.Namespace) -> int:
 
     g = nx.MultiDiGraph()
 
+    # ensure all functions from ground truth have an entry
+    for function in functions:
+        g.add_node(
+            base_address + function.start_rva,
+            address=base_address + function.start_rva,
+            type="function",
+        )
+
     for flow_graph in be2.flow_graph:
         datas: set[int] = set()
         callees: set[str] = set()
@@ -386,6 +408,58 @@ def generate_main(args: argparse.Namespace) -> int:
     return 0
 
 
+def _worker(args):
+
+    assemblage_database: Path
+    assemblage_directory: Path
+    graph_file: Path
+    binary_id: int
+
+    (assemblage_database, assemblage_directory, graph_file, binary_id) = args
+    if graph_file.is_file():
+        return
+
+    logger.info("processing: %d", binary_id)
+    # re-invoke this script's generate subcommand in a child process so one crash can't kill the pool
+    process = subprocess.run(
+        [sys.executable, __file__, "--debug", "generate", str(assemblage_database), str(assemblage_directory), str(binary_id)],
+        capture_output=True,
+        encoding="utf-8",
+    )
+    if process.returncode != 0:
+        logger.warning("failed: %d", binary_id)
+        logger.debug("%s", process.stderr)
+        return
+
+    graph_file.parent.mkdir(parents=True, exist_ok=True)
+    graph = process.stdout
+    graph_file.write_text(graph)
+
+
+def generate_all_main(args: argparse.Namespace) -> int:
+    if not args.assemblage_database.is_file():
+        raise ValueError("database doesn't exist")
+
+    db = Assemblage(args.assemblage_database, args.assemblage_directory)
+
+    output_directory: Path = args.output_directory
+    binary_ids = list(db.get_binary_ids())
+
+    with Pool(args.num_workers) as p:
+        _ = list(p.imap_unordered(
+            _worker,
+            ((
+                args.assemblage_database,
+                args.assemblage_directory,
+                output_directory / str(binary_id) / "graph.gexf",
+                binary_id,
+            )
+            for binary_id in binary_ids))
+        )
+
+    return 0
+
+
 def cluster_main(args: argparse.Namespace) -> int:
     if not args.graph.is_file():
         raise ValueError("graph file doesn't exist")
@@ -418,6 +492,15 @@ def main(argv=None) -> int:
     generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
     generate_parser.set_defaults(func=generate_main)
 
+    num_cores = os.cpu_count() or 1
+    default_workers = max(1, num_cores - 2)
+    generate_all_parser = subparsers.add_parser("generate_all", help="generate graphs for all samples")
+    generate_all_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
+    generate_all_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
+    generate_all_parser.add_argument("output_directory", type=Path, help="path to output directory")
+    generate_all_parser.add_argument("--num_workers", type=int, default=default_workers, help="number of workers to use")
+    generate_all_parser.set_defaults(func=generate_all_main)
+
     cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
     cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
     cluster_parser.set_defaults(func=cluster_main)