add codemap script

This commit is contained in:
Willi Ballenthin
2025-04-25 20:45:36 +02:00
parent 9d3d3be21d
commit d0bafd6ab7

490
scripts/codemap.py Normal file
View File

@@ -0,0 +1,490 @@
#!/usr/bin/env python
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "protobuf",
# "python-lancelot",
# "rich",
# ]
# ///
#
# TODO:
# - ignore stack cookie check
import sys
import json
import time
import logging
import argparse
import contextlib
from typing import Any
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
import lancelot
import rich.padding
import lancelot.be2utils
import google.protobuf.message
from rich.text import Text
from rich.theme import Theme
from rich.markup import escape
from rich.console import Console
from lancelot.be2utils.binexport2_pb2 import BinExport2
logger = logging.getLogger("codemap")
@contextlib.contextmanager
def timing(msg: str):
t0 = time.time()
yield
t1 = time.time()
logger.debug("perf: %s: %0.2fs", msg, t1 - t0)
class Renderer:
def __init__(self, console: Console):
self.console: Console = console
self.indent: int = 0
@contextlib.contextmanager
def indenting(self):
self.indent += 1
try:
yield
finally:
self.indent -= 1
@staticmethod
def markup(s: str, **kwargs) -> Text:
escaped_args = {k: (escape(v) if isinstance(v, str) else v) for k, v in kwargs.items()}
return Text.from_markup(s.format(**escaped_args))
def print(self, renderable, **kwargs):
if not kwargs:
return self.console.print(rich.padding.Padding(renderable, (0, 0, 0, self.indent * 2)))
assert isinstance(renderable, str)
return self.print(self.markup(renderable, **kwargs))
def writeln(self, s: str):
self.print(s)
@contextlib.contextmanager
def section(self, name):
if isinstance(name, str):
self.print("[title]{name}", name=name)
elif isinstance(name, Text):
name = name.copy()
name.stylize_before(self.console.get_style("title"))
self.print(name)
else:
raise ValueError("unexpected section name")
with self.indenting():
yield
@dataclass
class AssemblageLocation:
name: str
file: str
prototype: str
rva: int
@property
def path(self):
if not self.file.endswith(")"):
return self.file
return self.file.rpartition(" (")[0]
@classmethod
def from_dict(cls, data: dict[str, Any]):
return cls(
name=data["name"],
file=data["file"],
prototype=data["prototype"],
rva=data["function_start"],
)
@staticmethod
def from_json(doc: str):
return AssemblageLocation.from_dict(json.loads(doc))
def main(argv: list[str] | None = None):
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(description="Inspect BinExport2 files")
parser.add_argument("input_file", type=Path, help="path to input file")
parser.add_argument("--capa", type=Path, help="path to capa JSON results file")
parser.add_argument("--assemblage", type=Path, help="path to Assemblage JSONL file")
parser.add_argument("-d", "--debug", action="store_true", help="enable debugging output on STDERR")
parser.add_argument("-q", "--quiet", action="store_true", help="disable all output but errors")
args = parser.parse_args(args=argv)
logging.basicConfig()
if args.quiet:
logging.getLogger().setLevel(logging.WARNING)
elif args.debug:
logging.getLogger().setLevel(logging.DEBUG)
else:
logging.getLogger().setLevel(logging.INFO)
theme = Theme(
{
"decoration": "grey54",
"title": "yellow",
"key": "black",
"value": "blue",
"default": "black",
},
inherit=False,
)
console = Console(theme=theme, markup=False, emoji=False)
o = Renderer(console)
be2: BinExport2
buf: bytes
try:
# easiest way to determine if this is a BinExport2 proto is...
# to just try to decode it.
buf = args.input_file.read_bytes()
with timing("loading BinExport2"):
be2 = BinExport2()
be2.ParseFromString(buf)
except google.protobuf.message.DecodeError:
with timing("analyzing file"):
input_file: Path = args.input_file
buf = lancelot.get_binexport2_bytes_from_bytes(input_file.read_bytes())
with timing("loading BinExport2"):
be2 = BinExport2()
be2.ParseFromString(buf)
with timing("indexing BinExport2"):
idx = lancelot.be2utils.BinExport2Index(be2)
matches_by_function: defaultdict[int, set[str]] = defaultdict(set)
if args.capa:
with timing("loading capa"):
doc = json.loads(args.capa.read_text())
functions_by_basic_block: dict[int, int] = {}
for function in doc["meta"]["analysis"]["layout"]["functions"]:
for basic_block in function["matched_basic_blocks"]:
functions_by_basic_block[basic_block["address"]["value"]] = function["address"]["value"]
matches_by_address: defaultdict[int, set[str]] = defaultdict(set)
for rule_name, results in doc["rules"].items():
for location, _ in results["matches"]:
if location["type"] != "absolute":
continue
address = location["value"]
matches_by_address[location["value"]].add(rule_name)
for address, matches in matches_by_address.items():
if function := functions_by_basic_block.get(address):
if function in idx.thunks:
# forward any capa for a thunk to its target
# since viv may not recognize the thunk as a separate function.
logger.debug("forwarding capa matches from thunk 0x%x to 0x%x", function, idx.thunks[function])
function = idx.thunks[function]
matches_by_function[function].update(matches)
for match in matches:
logger.info("capa: 0x%x: %s", function, match)
else:
# we don't know which function this is.
# hopefully its a function recognized in our BinExport analysis.
# *shrug*
#
# apparently viv doesn't emit function entries for thunks?
# or somehow our layout is messed up.
if address in idx.thunks:
# forward any capa for a thunk to its target
# since viv may not recognize the thunk as a separate function.
logger.debug("forwarding capa matches from thunk 0x%x to 0x%x", address, idx.thunks[address])
address = idx.thunks[address]
# since we found the thunk, we know this is a BinExport-recognized function.
# so thats nice.
for match in matches:
logger.info("capa: 0x%x: %s", address, match)
else:
logger.warning("unknown address: 0x%x: %s", address, matches)
matches_by_function[address].update(matches)
# guess the base address (which BinExport2) does not track explicitly,
# by assuming it is the lowest mapped page.
base_address = min(map(lambda section: section.address, be2.section))
logging.info("guessed base address: 0x%x", base_address)
assemblage_locations_by_va: dict[int, AssemblageLocation] = {}
if args.assemblage:
with timing("loading assemblage"):
with args.assemblage.open("rt", encoding="utf-8") as f:
for line in f:
if not line:
continue
location = AssemblageLocation.from_json(line)
assemblage_locations_by_va[base_address + location.rva] = location
# update function names for the in-memory BinExport2 using Assemblage data.
# this won't affect the be2 on disk, because we don't serialize it back out.
for address, location in assemblage_locations_by_va.items():
if not location.name:
continue
if vertex_index := idx.vertex_index_by_address.get(address):
vertex = be2.call_graph.vertex[vertex_index].demangled_name = location.name
# index all the callers of each function, resolving thunks.
# idx.callers_by_vertex_id does not resolve thunks.
resolved_callers_by_vertex_id = defaultdict(set)
for edge in be2.call_graph.edge:
source_index = edge.source_vertex_index
if lancelot.be2utils.is_thunk_vertex(be2.call_graph.vertex[source_index]):
# we don't care about the callers that are thunks.
continue
if lancelot.be2utils.is_thunk_vertex(be2.call_graph.vertex[edge.target_vertex_index]):
thunk_vertex = be2.call_graph.vertex[edge.target_vertex_index]
thunk_address = thunk_vertex.address
target_address = idx.thunks[thunk_address]
target_index = idx.vertex_index_by_address[target_address]
logger.debug(
"call %s -(thunk)-> %s",
idx.get_function_name_by_vertex(source_index),
idx.get_function_name_by_vertex(target_index),
)
else:
target_index = edge.target_vertex_index
logger.debug(
"call %s -> %s",
idx.get_function_name_by_vertex(source_index),
idx.get_function_name_by_vertex(target_index),
)
resolved_callers_by_vertex_id[target_index].add(source_index)
t0 = time.time()
with o.section("meta"):
o.writeln(f"name: {be2.meta_information.executable_name}")
o.writeln(f"sha256: {be2.meta_information.executable_id}")
o.writeln(f"arch: {be2.meta_information.architecture_name}")
o.writeln(f"ts: {be2.meta_information.timestamp}")
with o.section("modules"):
for module in be2.module:
o.writeln(f"- {module.name}")
if not be2.module:
o.writeln("(none)")
with o.section("sections"):
for section in be2.section:
perms = ""
perms += "r" if section.flag_r else "-"
perms += "w" if section.flag_w else "-"
perms += "x" if section.flag_x else "-"
o.writeln(f"- {hex(section.address)} {perms} {hex(section.size)}")
with o.section("libraries"):
for library in be2.library:
o.writeln(
f"- {library.name:<12s} {'(static)' if library.is_static else ''}{(' at ' + hex(library.load_address)) if library.HasField('load_address') else ''}"
)
if not be2.library:
o.writeln("(none)")
vertex_order_by_address = {address: i for (i, address) in enumerate(idx.vertex_index_by_address.keys())}
with o.section("functions"):
last_address = None
for _, vertex_index in idx.vertex_index_by_address.items():
vertex = be2.call_graph.vertex[vertex_index]
vertex_order = vertex_order_by_address[vertex.address]
if vertex.HasField("library_index"):
continue
if vertex.HasField("module_index"):
continue
function_name = idx.get_function_name_by_vertex(vertex_index)
if last_address:
try:
last_path = assemblage_locations_by_va[last_address].path
path = assemblage_locations_by_va[vertex.address].path
if last_path != path:
o.print(o.markup("[blue]~~~~~~~~~~~~~~~~~~~~~~~~~~~~~[/] [title]file[/] {path}\n", path=path))
except KeyError:
pass
last_address = vertex.address
if lancelot.be2utils.is_thunk_vertex(vertex):
with o.section(
o.markup(
"thunk [default]{function_name}[/] [decoration]@ {function_address}[/]",
function_name=function_name,
function_address=hex(vertex.address),
)
):
continue
with o.section(
o.markup(
"function [default]{function_name}[/] [decoration]@ {function_address}[/]",
function_name=function_name,
function_address=hex(vertex.address),
)
):
if vertex.address in idx.thunks:
o.writeln("")
continue
# keep the xrefs separate from the calls, since they're visually hard to distinguish.
# use local index of callers that has resolved intermediate thunks,
# since they are sometimes stored in a physically distant location.
for caller_index in resolved_callers_by_vertex_id.get(vertex_index, []):
caller_vertex = be2.call_graph.vertex[caller_index]
caller_order = vertex_order_by_address[caller_vertex.address]
caller_delta = caller_order - vertex_order
if caller_delta < 0:
direction = ""
else:
direction = ""
o.print(
"xref: [decoration]{direction}[/] {name} [decoration]({delta:+})[/]",
direction=direction,
name=idx.get_function_name_by_vertex(caller_index),
delta=caller_delta,
)
if vertex.address not in idx.flow_graph_index_by_address:
num_basic_blocks = 0
num_instructions = 0
num_edges = 0
total_instruction_size = 0
else:
flow_graph_index = idx.flow_graph_index_by_address[vertex.address]
flow_graph = be2.flow_graph[flow_graph_index]
num_basic_blocks = len(flow_graph.basic_block_index)
num_instructions = sum(
len(list(idx.instruction_indices(be2.basic_block[bb_idx])))
for bb_idx in flow_graph.basic_block_index
)
num_edges = len(flow_graph.edge)
total_instruction_size = 0
for bb_idx in flow_graph.basic_block_index:
basic_block = be2.basic_block[bb_idx]
for _, instruction, _ in idx.basic_block_instructions(basic_block):
total_instruction_size += len(instruction.raw_bytes)
o.writeln(
f"B/E/I: {num_basic_blocks} / {num_edges} / {num_instructions} ({total_instruction_size} bytes)"
)
for match in matches_by_function.get(vertex.address, []):
o.writeln(f"capa: {match}")
if vertex.address in idx.flow_graph_index_by_address:
flow_graph_index = idx.flow_graph_index_by_address[vertex.address]
flow_graph = be2.flow_graph[flow_graph_index]
seen_callees = set()
for basic_block_index in flow_graph.basic_block_index:
basic_block = be2.basic_block[basic_block_index]
for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
if instruction.call_target:
for call_target_address in instruction.call_target:
if call_target_address in idx.thunks:
call_target_address = idx.thunks[call_target_address]
call_target_index = idx.vertex_index_by_address[call_target_address]
call_target_vertex = be2.call_graph.vertex[call_target_index]
if call_target_vertex.HasField("library_index"):
continue
if call_target_vertex.address in seen_callees:
continue
seen_callees.add(call_target_vertex.address)
call_target_order = vertex_order_by_address[call_target_address]
call_target_delta = call_target_order - vertex_order
call_target_name = idx.get_function_name_by_address(call_target_address)
if call_target_delta < 0:
direction = ""
else:
direction = ""
o.print(
"calls: [decoration]{direction}[/] {name} [decoration]({delta:+})[/]",
direction=direction,
name=call_target_name,
delta=call_target_delta,
)
for basic_block_index in flow_graph.basic_block_index:
basic_block = be2.basic_block[basic_block_index]
for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
if instruction.call_target:
for call_target_address in instruction.call_target:
call_target_index = idx.vertex_index_by_address[call_target_address]
call_target_vertex = be2.call_graph.vertex[call_target_index]
if not call_target_vertex.HasField("library_index"):
continue
if call_target_vertex.address in seen_callees:
continue
seen_callees.add(call_target_vertex.address)
call_target_name = idx.get_function_name_by_address(call_target_address)
o.print(
"api: {name}",
name=call_target_name,
)
seen_strings = set()
for basic_block_index in flow_graph.basic_block_index:
basic_block = be2.basic_block[basic_block_index]
for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
if instruction_index in idx.string_reference_index_by_source_instruction_index:
for string_reference_index in idx.string_reference_index_by_source_instruction_index[
instruction_index
]:
string_reference = be2.string_reference[string_reference_index]
string_index = string_reference.string_table_index
string = be2.string_table[string_index]
if string in seen_strings:
continue
seen_strings.add(string)
o.print(
'string: [decoration]"[/]{string}[decoration]"[/]',
string=string.rstrip(),
)
o.print("")
t1 = time.time()
logger.debug("perf: rendering BinExport2: %0.2fs", t1 - t0)
if __name__ == "__main__":
sys.exit(main())