Compare commits

...

3 Commits

View File

@@ -11,8 +11,9 @@ import logging
import argparse import argparse
import tempfile import tempfile
import contextlib import contextlib
import collections
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Iterable, Optional
from pathlib import Path from pathlib import Path
import rich import rich
@@ -51,6 +52,7 @@ class Method(str, Enum):
STRINGS = "strings" STRINGS = "strings"
THUNK = "thunk" THUNK = "thunk"
ENTRYPOINT = "entrypoint" ENTRYPOINT = "entrypoint"
CALLGRAPH = "callgraph"
class FunctionClassification(BaseModel): class FunctionClassification(BaseModel):
@@ -69,9 +71,19 @@ class FunctionClassification(BaseModel):
library_name: Optional[str] = None library_name: Optional[str] = None
library_version: Optional[str] = None library_version: Optional[str] = None
# additional note on the classification, TODO removeme if not useful beyond dev/debug
note: Optional[str] = None
class BinaryLayout(BaseModel):
va: int
# size of the function chunks in bytes
size: int
class FunctionIdResults(BaseModel): class FunctionIdResults(BaseModel):
function_classifications: List[FunctionClassification] function_classifications: List[FunctionClassification]
layout: List[BinaryLayout]
@contextlib.contextmanager @contextlib.contextmanager
@@ -110,11 +122,63 @@ def ida_session(input_path: Path, use_temp_dir=True):
t.unlink() t.unlink()
def get_library_called_functions(
function_classifications: list[FunctionClassification],
) -> Iterable[FunctionClassification]:
MAX_PASSES = 10
classifications_by_va = capa.analysis.strings.create_index(function_classifications, "va")
for n in range(MAX_PASSES):
found_new_lib_func = False
for fva in idautils.Functions():
if classifications_by_va.get(fva):
# already classified
continue
for ref in idautils.CodeRefsTo(fva, True):
f: idaapi.func_t = idaapi.get_func(ref)
if not f:
# no function associated with reference location
continue
ref_fva = f.start_ea
fname = idaapi.get_func_name(ref_fva)
if fname in ("___tmainCRTStartup",):
# ignore library functions, where we know that they call user-code
# TODO(mr): extend this list
continue
if classifications := classifications_by_va.get(ref_fva):
for c in classifications:
if c.classification == Classification.LIBRARY:
fc = FunctionClassification(
va=fva,
name=idaapi.get_func_name(fva),
classification=Classification.LIBRARY,
method=Method.CALLGRAPH,
note=f"called by 0x{ref_fva:x} ({c.method.value}{f', {c.library_name}@{c.library_version})' if c.library_name else ')'}",
)
classifications_by_va[fva].append(fc)
yield fc
found_new_lib_func = True
break
if not found_new_lib_func:
logger.debug("no update in pass %d, done here", n)
return
def is_thunk_function(fva): def is_thunk_function(fva):
f = idaapi.get_func(fva) f = idaapi.get_func(fva)
return bool(f.flags & idaapi.FUNC_THUNK) return bool(f.flags & idaapi.FUNC_THUNK)
def get_function_size(fva):
f = idaapi.get_func(fva)
assert f.start_ea == fva
return sum([end_ea - start_ea for (start_ea, end_ea) in idautils.Chunks(fva)])
def main(argv=None): def main(argv=None):
if argv is None: if argv is None:
argv = sys.argv[1:] argv = sys.argv[1:]
@@ -177,7 +241,11 @@ def main(argv=None):
for va in idautils.Functions(): for va in idautils.Functions():
name = idaapi.get_func_name(va) name = idaapi.get_func_name(va)
if name not in {"WinMain", }: if name not in {
"WinMain",
"_main",
"main",
}:
continue continue
function_classifications.append( function_classifications.append(
@@ -189,7 +257,11 @@ def main(argv=None):
) )
) )
doc = FunctionIdResults(function_classifications=[]) with capa.main.timing("call graph based library identification"):
for fc in get_library_called_functions(function_classifications):
function_classifications.append(fc)
doc = FunctionIdResults(function_classifications=[], layout=[])
classifications_by_va = capa.analysis.strings.create_index(function_classifications, "va") classifications_by_va = capa.analysis.strings.create_index(function_classifications, "va")
for va in idautils.Functions(): for va in idautils.Functions():
if classifications := classifications_by_va.get(va): if classifications := classifications_by_va.get(va):
@@ -203,20 +275,31 @@ def main(argv=None):
method=None, method=None,
) )
) )
doc.layout.append(
BinaryLayout(
va=va,
size=get_function_size(va),
)
)
if args.json: if args.json:
print(doc.model_dump_json()) # noqa: T201 print found print(doc.model_dump_json()) # noqa: T201 print found
else: else:
table = rich.table.Table() table = rich.table.Table(
table.add_column("FVA") "FVA",
table.add_column("CLASSIFICATION") "CLASSIFICATION",
table.add_column("METHOD") "METHOD",
table.add_column("FNAME") "FNAME",
table.add_column("EXTRA INFO") "EXTRA",
"SIZE"
)
classifications_by_va = capa.analysis.strings.create_index(doc.function_classifications, "va", sorted_=True) classifications_by_va = capa.analysis.strings.create_index(doc.function_classifications, "va", sorted_=True)
size_by_va = {layout.va: layout.size for layout in doc.layout}
size_by_classification = collections.defaultdict(int)
for va, classifications in classifications_by_va.items(): for va, classifications in classifications_by_va.items():
# TODO count of classifications if multiple?
name = ", ".join({c.name for c in classifications}) name = ", ".join({c.name for c in classifications})
if "sub_" in name: if "sub_" in name:
name = Text(name, style="grey53") name = Text(name, style="grey53")
@@ -224,17 +307,29 @@ def main(argv=None):
classification = {c.classification for c in classifications} classification = {c.classification for c in classifications}
method = {c.method for c in classifications if c.method} method = {c.method for c in classifications if c.method}
extra = {f"{c.library_name}@{c.library_version}" for c in classifications if c.library_name} extra = {f"{c.library_name}@{c.library_version}" for c in classifications if c.library_name}
note = {f"{c.note}" for c in classifications if c.note}
table.add_row( table.add_row(
hex(va), hex(va),
", ".join(classification) if classification != {"unknown"} else Text("unknown", style="grey53"), ", ".join(classification) if classification != {"unknown"} else Text("unknown", style="grey53"),
", ".join(method), ", ".join(method),
name, name,
", ".join(extra), f"{', '.join(extra)} {', '.join(note)}",
f"{size_by_va[va]}",
) )
size_by_classification["-".join(classification)] += size_by_va[va]
rich.print(table) rich.print(table)
stats_table = rich.table.Table(
"ID", rich.table.Column("SIZE", justify="right"), rich.table.Column("%", justify="right")
)
size_all = sum(size_by_classification.values())
for k, s in size_by_classification.items():
stats_table.add_row(k, f"{s:d}", f"{100 * s / size_all:.2f}")
rich.print(stats_table)
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())