codecut: bulk generate graphs

This commit is contained in:
Willi Ballenthin
2024-11-06 13:55:13 +00:00
parent f296e7d423
commit 3b1a8f5b5a

View File

@@ -1,11 +1,14 @@
import os
import sys
import json
import logging
import sqlite3
import argparse
import subprocess
from typing import Iterator, Optional
from pathlib import Path
from dataclasses import dataclass
from multiprocessing import Pool
import pefile
import lancelot
@@ -52,6 +55,9 @@ def compute_thunks(be2: BinExport2, idx: BinExport2Index) -> dict[int, int]:
# that we can't actually resolve here.
break
if len(thunk_callees) != 1:
for thunk_callee in thunk_callees:
logger.warning("%s", hex(be2.call_graph.vertex[thunk_callee].address))
assert len(thunk_callees) == 1, f"thunk @ {hex(addr)} failed"
thunked_vertex_idx: int = thunk_callees[0]
@@ -209,6 +215,14 @@ class Assemblage:
path = self.get_path_by_binary_id(binary_id)
return pefile.PE(data=path.read_bytes(), fast_load=True)
def get_binary_ids(self) -> Iterator[int]:
    """Yield every distinct binary_id in the assemblage table, smallest first.

    Rows are pulled from the cursor one at a time rather than fetched in
    bulk, so callers may stop iterating early.
    """
    query = "SELECT DISTINCT binary_id FROM assemblage ORDER BY binary_id ASC;"
    # NOTE(review): self.conn is presumably a sqlite3 connection, whose
    # context manager commits on success and rolls back on error - confirm.
    with self.conn:
        cursor = self.conn.execute(query)
        while record := cursor.fetchone():
            # each record is a one-column row holding the binary_id
            yield record[0]
def generate_main(args: argparse.Namespace) -> int:
if not args.assemblage_database.is_file():
@@ -247,6 +261,14 @@ def generate_main(args: argparse.Namespace) -> int:
g = nx.MultiDiGraph()
# ensure all functions from ground truth have an entry
for function in functions:
g.add_node(
base_address + function.start_rva,
address=base_address + function.start_rva,
type="function",
)
for flow_graph in be2.flow_graph:
datas: set[int] = set()
callees: set[str] = set()
@@ -386,6 +408,55 @@ def generate_main(args: argparse.Namespace) -> int:
return 0
def _worker(args):
    """Generate and save the graph for one binary; meant to run as a Pool task.

    Args:
        args: a tuple of (assemblage_database, assemblage_directory,
            graph_file, binary_id), packed into a single value because the
            pool hands each task exactly one argument.

    The graph is produced by re-running this script's `generate` subcommand
    in a child process and capturing its stdout, which is then written to
    graph_file. If graph_file already exists the binary is skipped, so an
    interrupted bulk run can be resumed. A failing child process is logged
    and skipped rather than raised.
    """
    assemblage_database: Path
    assemblage_directory: Path
    graph_file: Path
    binary_id: int
    (assemblage_database, assemblage_directory, graph_file, binary_id) = args

    if graph_file.is_file():
        # already generated by a previous run
        return

    logger.info("processing: %d", binary_id)
    process = subprocess.run(
        [
            # run the child under the same interpreter (and virtualenv) as
            # the parent; a bare "python" may be a different one or missing.
            sys.executable or "python",
            __file__,
            "--debug",
            "generate",
            str(assemblage_database),
            str(assemblage_directory),
            str(binary_id),
        ],
        capture_output=True,
        encoding="utf-8",
    )
    if process.returncode != 0:
        logger.warning("failed: %d", binary_id)
        logger.debug("%s", process.stderr)
        return

    # parents=True so a not-yet-existing output root doesn't abort the task
    graph_file.parent.mkdir(parents=True, exist_ok=True)

    # Write to a temporary sibling and rename it into place, so that an
    # interrupted run never leaves a truncated graph_file that the
    # is_file() check above would later mistake for a finished result.
    # stdout was decoded as UTF-8, so encode it the same way on the way out.
    partial_file = graph_file.with_name(graph_file.name + ".tmp")
    partial_file.write_text(process.stdout, encoding="utf-8")
    partial_file.replace(graph_file)
def generate_all_main(args: argparse.Namespace) -> int:
    """Generate graphs for every binary in the Assemblage database, in parallel.

    Args:
        args: parsed command line arguments; reads assemblage_database,
            assemblage_directory, output_directory and num_workers.

    Returns:
        0 on completion. Per-binary outcomes are handled by _worker.

    Raises:
        ValueError: if the database file doesn't exist.

    Each binary's graph is written to <output_directory>/<binary_id>/graph.gexf.
    """
    if not args.assemblage_database.is_file():
        raise ValueError("database doesn't exist")

    db = Assemblage(args.assemblage_database, args.assemblage_directory)

    output_directory: Path = args.output_directory
    # create the output root up front so the workers can add their
    # per-binary subdirectories beneath it.
    output_directory.mkdir(parents=True, exist_ok=True)

    # collect all ids before the pool starts, so the database query is
    # finished by the time the workers are running.
    binary_ids = list(db.get_binary_ids())

    tasks = (
        (
            args.assemblage_database,
            args.assemblage_directory,
            output_directory / str(binary_id) / "graph.gexf",
            binary_id,
        )
        for binary_id in binary_ids
    )

    with Pool(args.num_workers) as p:
        # the workers return nothing useful; just drain the iterator so
        # that every task runs to completion before the pool is torn down.
        for _ in p.imap_unordered(_worker, tasks):
            pass

    return 0
def cluster_main(args: argparse.Namespace) -> int:
if not args.graph.is_file():
raise ValueError("graph file doesn't exist")
@@ -418,6 +489,15 @@ def main(argv=None) -> int:
generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
generate_parser.set_defaults(func=generate_main)
num_cores = os.cpu_count()
default_workers = max(1, num_cores - 2)
generate_all_parser = subparsers.add_parser("generate_all", help="generate graphs for all samples")
generate_all_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
generate_all_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
generate_all_parser.add_argument("output_directory", type=Path, help="path to output directory")
generate_all_parser.add_argument("--num_workers", type=int, default=default_workers, help="number of workers to use")
generate_all_parser.set_defaults(func=generate_all_main)
cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
cluster_parser.set_defaults(func=cluster_main)