Compare commits

..

3 Commits

2 changed files with 105 additions and 194 deletions

View File

@@ -11,8 +11,9 @@ import logging
import argparse
import tempfile
import contextlib
import collections
from enum import Enum
from typing import List, Optional
from typing import List, Iterable, Optional
from pathlib import Path
import rich
@@ -51,6 +52,7 @@ class Method(str, Enum):
STRINGS = "strings"
THUNK = "thunk"
ENTRYPOINT = "entrypoint"
CALLGRAPH = "callgraph"
class FunctionClassification(BaseModel):
@@ -69,9 +71,19 @@ class FunctionClassification(BaseModel):
library_name: Optional[str] = None
library_version: Optional[str] = None
# additional note on the classification, TODO removeme if not useful beyond dev/debug
note: Optional[str] = None
class BinaryLayout(BaseModel):
    """Size/placement record for one function, keyed by its start address."""

    # start virtual address of the function
    va: int
    # size of the function chunks in bytes
    size: int
class FunctionIdResults(BaseModel):
    """Top-level result document emitted by this tool (see --json output)."""

    # classification results; one function (va) may have multiple entries
    function_classifications: List[FunctionClassification]
    # per-function size info, used to compute per-classification size statistics
    layout: List[BinaryLayout]
@contextlib.contextmanager
@@ -110,11 +122,63 @@ def ida_session(input_path: Path, use_temp_dir=True):
t.unlink()
def get_library_called_functions(
    function_classifications: list[FunctionClassification],
) -> Iterable[FunctionClassification]:
    """Propagate LIBRARY classifications along the call graph.

    A function that is not yet classified but is called from a function
    already classified as LIBRARY is itself classified as LIBRARY
    (method=CALLGRAPH). Runs up to MAX_PASSES fixed-point passes and stops
    early once a pass finds nothing new. Yields each new classification.

    NOTE(review): requires an active IDA session (idautils/idaapi).
    """
    # upper bound on propagation passes; each pass can only grow the
    # classified set, so iteration terminates even without early convergence
    MAX_PASSES = 10

    # index: function va -> list of FunctionClassification
    # NOTE(review): the append below assumes create_index yields list-valued
    # entries for missing keys (e.g. a defaultdict) — otherwise it KeyErrors;
    # confirm against capa.analysis.strings.create_index.
    classifications_by_va = capa.analysis.strings.create_index(function_classifications, "va")

    for n in range(MAX_PASSES):
        found_new_lib_func = False

        for fva in idautils.Functions():
            if classifications_by_va.get(fva):
                # already classified
                continue

            # inspect every code reference into this function: if any caller
            # is a known LIBRARY function, inherit that classification
            for ref in idautils.CodeRefsTo(fva, True):
                f: idaapi.func_t = idaapi.get_func(ref)
                if not f:
                    # no function associated with reference location
                    continue

                ref_fva = f.start_ea
                fname = idaapi.get_func_name(ref_fva)
                if fname in ("___tmainCRTStartup",):
                    # ignore library functions, where we know that they call user-code
                    # TODO(mr): extend this list
                    continue

                if classifications := classifications_by_va.get(ref_fva):
                    for c in classifications:
                        if c.classification == Classification.LIBRARY:
                            fc = FunctionClassification(
                                va=fva,
                                name=idaapi.get_func_name(fva),
                                classification=Classification.LIBRARY,
                                method=Method.CALLGRAPH,
                                note=f"called by 0x{ref_fva:x} ({c.method.value}{f', {c.library_name}@{c.library_version})' if c.library_name else ')'}",
                            )
                            classifications_by_va[fva].append(fc)
                            yield fc

                            found_new_lib_func = True
                            # one LIBRARY hit from this caller is enough;
                            # NOTE(review): remaining refs are still scanned,
                            # so a function may collect one entry per library
                            # caller — appears intentional (multi-entry index)
                            break

        if not found_new_lib_func:
            logger.debug("no update in pass %d, done here", n)
            return
def is_thunk_function(fva):
    """Return True if the function at `fva` has IDA's FUNC_THUNK flag set.

    Returns False when `fva` does not belong to any function: idaapi.get_func
    returns None for such addresses, and the original code then crashed with
    AttributeError when reading `.flags`.
    """
    f = idaapi.get_func(fva)
    # guard: get_func yields None for addresses outside any function
    return f is not None and bool(f.flags & idaapi.FUNC_THUNK)
def get_function_size(fva):
    """Return the total size in bytes of the function starting at `fva`,
    summed over all of its chunks (IDA functions may consist of several
    non-contiguous chunks).

    `fva` must be a function start address, not an interior address.
    """
    f = idaapi.get_func(fva)
    # sanity check: the given address must be the chunk owner's start
    assert f.start_ea == fva
    # generator expression: no need to materialize a throwaway list in sum()
    return sum(end_ea - start_ea for (start_ea, end_ea) in idautils.Chunks(fva))
def main(argv=None):
if argv is None:
argv = sys.argv[1:]
@@ -177,7 +241,11 @@ def main(argv=None):
for va in idautils.Functions():
name = idaapi.get_func_name(va)
if name not in {"WinMain", }:
if name not in {
"WinMain",
"_main",
"main",
}:
continue
function_classifications.append(
@@ -189,7 +257,11 @@ def main(argv=None):
)
)
doc = FunctionIdResults(function_classifications=[])
with capa.main.timing("call graph based library identification"):
for fc in get_library_called_functions(function_classifications):
function_classifications.append(fc)
doc = FunctionIdResults(function_classifications=[], layout=[])
classifications_by_va = capa.analysis.strings.create_index(function_classifications, "va")
for va in idautils.Functions():
if classifications := classifications_by_va.get(va):
@@ -203,20 +275,31 @@ def main(argv=None):
method=None,
)
)
doc.layout.append(
BinaryLayout(
va=va,
size=get_function_size(va),
)
)
if args.json:
print(doc.model_dump_json()) # noqa: T201 print found
else:
table = rich.table.Table()
table.add_column("FVA")
table.add_column("CLASSIFICATION")
table.add_column("METHOD")
table.add_column("FNAME")
table.add_column("EXTRA INFO")
table = rich.table.Table(
"FVA",
"CLASSIFICATION",
"METHOD",
"FNAME",
"EXTRA",
"SIZE"
)
classifications_by_va = capa.analysis.strings.create_index(doc.function_classifications, "va", sorted_=True)
size_by_va = {layout.va: layout.size for layout in doc.layout}
size_by_classification = collections.defaultdict(int)
for va, classifications in classifications_by_va.items():
# TODO count of classifications if multiple?
name = ", ".join({c.name for c in classifications})
if "sub_" in name:
name = Text(name, style="grey53")
@@ -224,17 +307,29 @@ def main(argv=None):
classification = {c.classification for c in classifications}
method = {c.method for c in classifications if c.method}
extra = {f"{c.library_name}@{c.library_version}" for c in classifications if c.library_name}
note = {f"{c.note}" for c in classifications if c.note}
table.add_row(
hex(va),
", ".join(classification) if classification != {"unknown"} else Text("unknown", style="grey53"),
", ".join(method),
name,
", ".join(extra),
f"{', '.join(extra)} {', '.join(note)}",
f"{size_by_va[va]}",
)
size_by_classification["-".join(classification)] += size_by_va[va]
rich.print(table)
stats_table = rich.table.Table(
"ID", rich.table.Column("SIZE", justify="right"), rich.table.Column("%", justify="right")
)
size_all = sum(size_by_classification.values())
for k, s in size_by_classification.items():
stats_table.add_row(k, f"{s:d}", f"{100 * s / size_all:.2f}")
rich.print(stats_table)
if __name__ == "__main__":
sys.exit(main())

View File

@@ -1,184 +0,0 @@
import sys
import sqlite3
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator

import pefile

import capa.main
@dataclass
class AssemblageRow:
    """One flattened record of the `assemblage` view: a binary joined with
    one of its functions and one of that function's rva ranges."""

    # from table: binaries
    binary_id: int
    file_name: str
    platform: str
    build_mode: str
    toolset_version: str
    github_url: str
    optimization: str
    repo_last_update: int
    size: int
    path: str
    license: str
    binary_hash: str
    repo_commit_hash: str

    # from table: functions
    function_id: int
    function_name: str
    function_hash: str
    top_comments: str
    source_codes: str
    prototype: str
    _source_file: str

    # from table: rvas
    rva_id: int
    start_rva: int
    end_rva: int

    @property
    def source_file(self):
        # assemblage appends metadata suffixes to the source file path,
        # e.g. " (MD5: ...)" or " (0x3: ..."; strip both to get the bare path
        cleaned, _, _ = self._source_file.partition(" (MD5: ")
        cleaned, _, _ = cleaned.partition(" (0x3: ")
        return cleaned
class Assemblage:
    """Accessor for an Assemblage dataset: a sqlite metadata database plus a
    directory containing the sample files.

    Construction opens the database, tunes pragmas, and idempotently creates
    the indexes and the `assemblage` view joining binaries, functions, rvas.
    """

    conn: sqlite3.Connection
    samples: Path

    def __init__(self, db: Path, samples: Path):
        super().__init__()
        self.db = db
        self.samples = samples

        self.conn = sqlite3.connect(self.db)
        with self.conn:
            self.conn.executescript("""
                PRAGMA journal_mode = WAL;
                PRAGMA synchronous = NORMAL;
                PRAGMA busy_timeout = 5000;
                PRAGMA cache_size = -20000; -- 20MB
                PRAGMA foreign_keys = true;
                PRAGMA temp_store = memory;

                BEGIN IMMEDIATE TRANSACTION;

                CREATE INDEX IF NOT EXISTS idx__functions__binary_id ON functions (binary_id);
                CREATE INDEX IF NOT EXISTS idx__rvas__function_id ON rvas (function_id);

                CREATE VIEW IF NOT EXISTS assemblage AS
                SELECT
                    binaries.id AS binary_id,
                    binaries.file_name AS file_name,
                    binaries.platform AS platform,
                    binaries.build_mode AS build_mode,
                    binaries.toolset_version AS toolset_version,
                    binaries.github_url AS github_url,
                    binaries.optimization AS optimization,
                    binaries.repo_last_update AS repo_last_update,
                    binaries.size AS size,
                    binaries.path AS path,
                    binaries.license AS license,
                    binaries.hash AS hash,
                    binaries.repo_commit_hash AS repo_commit_hash,
                    functions.id AS function_id,
                    functions.name AS function_name,
                    functions.hash AS function_hash,
                    functions.top_comments AS top_comments,
                    functions.source_codes AS source_codes,
                    functions.prototype AS prototype,
                    functions.source_file AS source_file,
                    rvas.id AS rva_id,
                    rvas.start AS start_rva,
                    rvas.end AS end_rva
                FROM binaries
                JOIN functions ON binaries.id = functions.binary_id
                JOIN rvas ON functions.id = rvas.function_id;
            """)

    def get_row_by_binary_id(self, binary_id: int) -> AssemblageRow:
        """Return one (arbitrary first) joined row for `binary_id`.

        Raises TypeError if the binary id is unknown (fetchone() yields None,
        which cannot be unpacked into AssemblageRow).
        """
        with self.conn:
            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id, ))
            return AssemblageRow(*cur.fetchone())

    def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
        """Yield every joined row (one per rva) for `binary_id`.

        Fixed: this is a generator; it was incorrectly annotated as returning
        a single AssemblageRow. Also iterates the cursor directly instead of
        a manual fetchone() loop.
        """
        with self.conn:
            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id, ))
            # sqlite3 cursors are iterable and stop at exhaustion
            for row in cur:
                yield AssemblageRow(*row)

    def get_path_by_binary_id(self, binary_id: int) -> Path:
        """Return the on-disk path of the sample for `binary_id`, rooted at
        the samples directory.

        Raises TypeError if the binary id is unknown (fetchone() yields None).
        """
        with self.conn:
            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id, ))
            return self.samples / cur.fetchone()[0]

    def get_pe_by_binary_id(self, binary_id: int) -> pefile.PE:
        """Load the sample for `binary_id` as a pefile.PE (fast_load=True,
        i.e. without full directory parsing)."""
        path = self.get_path_by_binary_id(binary_id)
        return pefile.PE(data=path.read_bytes(), fast_load=True)
def main(argv=None):
    """CLI entry point: print the function layout of one Assemblage binary."""
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description="Inspect object boundaries in compiled programs")
    capa.main.install_common_args(parser, wanted={})
    parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
    parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
    parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
    args = parser.parse_args(args=argv)

    try:
        capa.main.handle_common_args(args)
    except capa.main.ShouldExitError as e:
        return e.status_code

    if not args.assemblage_database.is_file():
        raise ValueError("database doesn't exist")

    db = Assemblage(args.assemblage_database, args.assemblage_directory)

    # print(db.get_row_by_binary_id(args.binary_id))
    # print(db.get_pe_by_binary_id(args.binary_id))

    # small local record: just the fields we display below
    @dataclass
    class Function:
        file: str
        name: str
        start_rva: int
        end_rva: int

    functions = []
    for record in db.get_rows_by_binary_id(args.binary_id):
        functions.append(
            Function(
                file=record.source_file,
                name=record.function_name,
                start_rva=record.start_rva,
                end_rva=record.end_rva,
            )
        )

    import rich
    import rich.table

    print(db.get_path_by_binary_id(args.binary_id))

    table = rich.table.Table()
    table.add_column("rva")
    table.add_column("filename")
    table.add_column("name")

    # display in address order
    functions.sort(key=lambda fn: fn.start_rva)
    for fn in functions:
        table.add_row(hex(fn.start_rva), fn.file, fn.name)
    rich.print(table)

    # db.conn.close()


if __name__ == "__main__":
    sys.exit(main())