From 6fc4567f0c49a9b2ea917b285d626a43b4769803 Mon Sep 17 00:00:00 2001
From: Willi Ballenthin
Date: Tue, 12 Nov 2024 14:43:32 +0000
Subject: [PATCH] codecut: better graph structure

---
 scripts/codecut.py | 191 ++++++++++++++++++++++++++++-----------
 1 file changed, 118 insertions(+), 73 deletions(-)

diff --git a/scripts/codecut.py b/scripts/codecut.py
index 67ffa959..e85d34c7 100644
--- a/scripts/codecut.py
+++ b/scripts/codecut.py
@@ -230,29 +230,21 @@ def generate_main(args: argparse.Namespace) -> int:
 
     db = Assemblage(args.assemblage_database, args.assemblage_directory)
 
-    @dataclass
-    class Function:
-        file: str
-        name: str
-        start_rva: int
-        end_rva: int
-
-    functions = [
-        Function(
-            file=m.source_file,
-            name=m.function_name,
-            start_rva=m.start_rva,
-            end_rva=m.end_rva,
-        )
-        for m in db.get_rows_by_binary_id(args.binary_id)
-    ]
-
     pe = db.get_pe_by_binary_id(args.binary_id)
     base_address: int = pe.OPTIONAL_HEADER.ImageBase
 
+    functions_by_address = {
+        base_address + function.start_rva: function for function in db.get_rows_by_binary_id(args.binary_id)
+    }
+
+    binary_hash = db.get_row_by_binary_id(args.binary_id).binary_hash
+
+    def make_node_id(address: int) -> str:
+        return f"{binary_hash}:{address:x}"
+
     pe_path = db.get_path_by_binary_id(args.binary_id)
     be2: BinExport2 = lancelot.get_binexport2_from_bytes(
-        pe_path.read_bytes(), function_hints=[base_address + function.start_rva for function in functions]
+        pe_path.read_bytes(), function_hints=list(functions_by_address.keys())
     )
 
     idx = lancelot.be2utils.BinExport2Index(be2)
@@ -262,16 +254,16 @@ def generate_main(args: argparse.Namespace) -> int:
     g = nx.MultiDiGraph()
 
     # ensure all functions from ground truth have an entry
-    for function in functions:
+    for address, function in functions_by_address.items():
         g.add_node(
-            base_address + function.start_rva,
-            address=base_address + function.start_rva,
+            make_node_id(address),
+            address=address,
             type="function",
         )
 
     for flow_graph in be2.flow_graph:
         datas: set[int] = set()
-        callees: set[str] = set()
+        callees: set[int] = set()
 
         entry_basic_block_index: int = flow_graph.entry_basic_block_index
         flow_graph_address: int = idx.get_basic_block_address(entry_basic_block_index)
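Note on the node-id scheme introduced above: keying nodes by the string `binary_hash:hex_address` rather than by the bare integer address keeps ids unique if graphs from several binaries are later merged or compared. A minimal sketch of the scheme, with a hypothetical `parse_node_id` inverse that is not part of this patch (in the patch, make_node_id closes over the binary's hash; here it takes the hash as a parameter so the sketch is self-contained):

    def make_node_id(binary_hash: str, address: int) -> str:
        # same format as the patch: "<binary hash>:<lowercase hex address>"
        return f"{binary_hash}:{address:x}"

    def parse_node_id(node_id: str) -> tuple[str, int]:
        # split on the last ":" so a ":" inside the hash portion could not confuse us.
        binary_hash, _, hex_address = node_id.rpartition(":")
        return binary_hash, int(hex_address, 16)

    assert parse_node_id(make_node_id("6fc4567f", 0x401000)) == ("6fc4567f", 0x401000)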
@@ -309,81 +301,127 @@ def generate_main(args: argparse.Namespace) -> int:
         name = idx.get_function_name_by_vertex(vertex_index)
 
         g.add_node(
-            flow_graph_address,
+            make_node_id(flow_graph_address),
             address=flow_graph_address,
             type="function",
         )
 
         if datas or callees:
             logger.info("%s @ 0x%X:", name, flow_graph_address)
-            for data in sorted(datas):
-                logger.info(" - 0x%X", data)
+            for data_address in sorted(datas):
+                logger.info(" - 0x%X", data_address)
 
                 g.add_node(
-                    data,
-                    address=data,
+                    make_node_id(data_address),
+                    address=data_address,
                     type="data",
                 )
                 g.add_edge(
-                    flow_graph_address,
-                    data,
+                    make_node_id(flow_graph_address),
+                    make_node_id(data_address),
                     key="reference",
-                    source_address=flow_graph_address,
-                    destination_address=data,
                 )
 
             for callee in sorted(callees):
                 logger.info(" - %s", idx.get_function_name_by_address(callee))
 
                 g.add_node(
-                    callee,
+                    make_node_id(callee),
                     address=callee,
                     type="function",
                 )
                 g.add_edge(
-                    flow_graph_address,
-                    callee,
+                    make_node_id(flow_graph_address),
+                    make_node_id(callee),
                     key="call",
-                    source_address=flow_graph_address,
-                    destination_address=callee,
                 )
 
         else:
             logger.info("%s @ 0x%X: (none)", name, flow_graph_address)
 
-    for section in pe.sections:
-        # within each section, emit a neighbor edge for each pair of neighbors.
+    for node, attrs in g.nodes(data=True):
+        if attrs["type"] != "function":
+            continue
+
+        if f := functions_by_address.get(attrs["address"]):
+            attrs["name"] = f.function_name
+            attrs["file"] = f.file_name
+
+    for section in pe.sections:
+        # Within each section, emit a neighbor edge for each pair of neighbors.
+        # Neighbors only link nodes of the same type, because Assemblage doesn't
+        # have ground truth for data items, so we don't quite know where to split.
+        # Consider this situation:
+        #
+        #     moduleA::func1
+        #     --- cut ---
+        #     moduleB::func1
+        #
+        # That one is OK, but this one is hard:
+        #
+        #     moduleA::func1
+        #     --- cut??? ---
+        #     dataZ
+        #     --- or cut here??? ---
+        #     moduleB::func1
+        #
+        # Does the cut go before or after dataZ?
+        # So, we only have neighbor edges within functions, and within data items.
+        # For data items, we don't allow interspersed functions.
 
-        section_nodes = [
-            node
-            for node, attrs in g.nodes(data=True)
-            if (section.VirtualAddress + base_address)
-            <= attrs["address"]
-            < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
+        section_nodes = sorted(
+            [
+                (node, attrs)
+                for node, attrs in g.nodes(data=True)
+                if (section.VirtualAddress + base_address)
+                <= attrs["address"]
+                < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
+            ],
+            key=lambda p: p[1]["address"],
+        )
+
+        # add neighbor edges between data items.
+        # the data items must not be separated by any functions.
+        for i in range(1, len(section_nodes)):
+            a, a_attrs = section_nodes[i - 1]
+            b, b_attrs = section_nodes[i]
+
+            if a_attrs["type"] != "data":
+                continue
+
+            if b_attrs["type"] != "data":
+                continue
+
+            g.add_edge(a, b, key="neighbor")
+            g.add_edge(b, a, key="neighbor")
+
+        section_functions = [
+            (node, attrs)
+            for node, attrs in section_nodes
+            if attrs["type"] == "function"
+            # we only have ground truth for the known functions,
+            # so only consider those in the function neighbor graph.
+            and attrs["address"] in functions_by_address
         ]
 
-        for i in range(1, len(section_nodes)):
-            a = section_nodes[i - 1]
-            b = section_nodes[i]
+        # add neighbor edges between functions.
+        # we drop the potentially interspersed data items before computing these edges.
+        for i in range(1, len(section_functions)):
+            a, a_attrs = section_functions[i - 1]
+            b, b_attrs = section_functions[i]
+            is_boundary = a_attrs["file"] != b_attrs["file"]
 
-            g.add_edge(
-                a,
-                b,
-                key="neighbor",
-                source_address=a,
-                destination_address=b,
-            )
-
-    for function in functions:
-        g.nodes[base_address + function.start_rva]["name"] = function.name
-        g.nodes[base_address + function.start_rva]["file"] = function.file
+            g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
+            g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)
 
     # rename unknown functions like: sub_401000
     for n, attrs in g.nodes(data=True):
         if attrs["type"] != "function":
             continue
+
         if "name" in attrs:
             continue
-        attrs["name"] = f"sub_{n:x}"
+
+        attrs["name"] = f"sub_{attrs['address']:x}"
 
     # assign human-readable repr to all nodes
     # assign is_import=bool to functions
@@ -394,11 +432,11 @@ def generate_main(args: argparse.Namespace) -> int:
                 attrs["repr"] = attrs["name"]
                 attrs["is_import"] = "!" in attrs["name"]
             case "data":
-                if string := read_string(address_space, n):
+                if string := read_string(address_space, attrs["address"]):
                     attrs["repr"] = json.dumps(string)
                     attrs["is_string"] = True
                 else:
-                    attrs["repr"] = f"data_{n:x}"
+                    attrs["repr"] = f"data_{attrs['address']:x}"
                     attrs["is_string"] = False
 
     for line in nx.generate_gexf(g):
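Note on the neighbor edges above: within a section, consecutive known functions are linked, and an adjacency is a source-file boundary exactly when the two functions come from different files (a_attrs["file"] != b_attrs["file"]). A toy reproduction of that pass; the node ids and file names here are invented for illustration:

    import networkx as nx

    g = nx.MultiDiGraph()
    g.add_node("h:1000", address=0x1000, type="function", file="moduleA.cpp")
    g.add_node("h:2000", address=0x2000, type="function", file="moduleA.cpp")
    g.add_node("h:3000", address=0x3000, type="function", file="moduleB.cpp")

    # sort functions by address, then link consecutive pairs, as the patch does per section.
    functions = sorted(
        (n for n, a in g.nodes(data=True) if a["type"] == "function"),
        key=lambda n: g.nodes[n]["address"],
    )
    for a, b in zip(functions, functions[1:]):
        is_boundary = g.nodes[a]["file"] != g.nodes[b]["file"]
        g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
        g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)

    # the only cut falls between moduleA and moduleB.
    cuts = [
        (u, v)
        for u, v, k, attrs in g.edges(keys=True, data=True)
        if k == "neighbor" and attrs["is_source_file_boundary"]
    ]
    assert ("h:2000", "h:3000") in cuts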
@@ -441,21 +479,26 @@ def generate_all_main(args: argparse.Namespace) -> int:
 
     db = Assemblage(args.assemblage_database, args.assemblage_directory)
 
-    output_directory: Path = args.output_directory
     binary_ids = list(db.get_binary_ids())
 
     with Pool(args.num_workers) as p:
-        _ = list(p.imap_unordered(
-            _worker,
-            ((
-                args.assemblage_database,
-                args.assemblage_directory,
-                args.output_directory / str(binary_id) / "graph.gexf",
-                binary_id
-            )
-            for binary_id in binary_ids))
+        _ = list(
+            p.imap_unordered(
+                _worker,
+                (
+                    (
+                        args.assemblage_database,
+                        args.assemblage_directory,
+                        args.output_directory / str(binary_id) / "graph.gexf",
+                        binary_id,
+                    )
+                    for binary_id in binary_ids
+                ),
+            )
         )
 
+    return 0
+
 
 def cluster_main(args: argparse.Namespace) -> int:
     if not args.graph.is_file():
@@ -489,13 +532,15 @@ def main(argv=None) -> int:
     generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
     generate_parser.set_defaults(func=generate_main)
 
-    num_cores = os.cpu_count()
+    num_cores = os.cpu_count() or 1
    default_workers = max(1, num_cores - 2)
     generate_all_parser = subparsers.add_parser("generate_all", help="generate graphs for all samples")
     generate_all_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
     generate_all_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
     generate_all_parser.add_argument("output_directory", type=Path, help="path to output directory")
-    generate_all_parser.add_argument("--num_workers", type=int, default=default_workers, help="number of workers to use")
+    generate_all_parser.add_argument(
+        "--num_workers", type=int, default=default_workers, help="number of workers to use"
+    )
     generate_all_parser.set_defaults(func=generate_all_main)
 
     cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
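Note on consuming the output: once graphs are generated, the ground-truth cut points can be tallied straight from a GEXF file. A sketch, assuming networkx's GEXF round-trip preserves the boolean edge attribute; the path below is hypothetical (generate_all writes output_directory/<binary_id>/graph.gexf):

    import networkx as nx

    g = nx.read_gexf("output/1/graph.gexf")  # hypothetical location

    total = 0
    cuts = 0
    for _, _, attrs in g.edges(data=True):
        # only the function neighbor edges carry this attribute;
        # call/reference edges and data neighbor edges do not.
        if "is_source_file_boundary" not in attrs:
            continue
        total += 1
        cuts += bool(attrs["is_source_file_boundary"])

    # neighbor edges are mirrored (a->b and b->a), so each adjacency is counted twice.
    print(f"{cuts // 2} source-file boundaries across {total // 2} function adjacencies")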