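"""
Identify object boundaries in compiled programs.

Builds a graph of the functions and data items in a PE file, using BinExport2
(via lancelot) for disassembly and an Assemblage database for ground truth about
which source file each function came from. The graph is emitted as GEXF and can
then be clustered into Louvain communities or loaded into a PyTorch Geometric
HeteroData object.

Example invocations (the script name and paths below are placeholders):

    python boundaries.py generate assemblage.sqlite samples/ 1234 > graph.gexf
    python boundaries.py generate_all assemblage.sqlite samples/ graphs/
    python boundaries.py cluster graphs/1234/graph.gexf
    python boundaries.py train graphs/1234/graph.gexf
"""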
import os
import sys
import json
import logging
import sqlite3
import argparse
import subprocess
from typing import Iterator, Optional, Literal
from pathlib import Path
from dataclasses import dataclass
from multiprocessing import Pool

import pefile
import lancelot
import networkx as nx
import lancelot.be2utils
from lancelot.be2utils import AddressSpace, BinExport2Index, ReadMemoryError
from lancelot.be2utils.binexport2_pb2 import BinExport2

import capa.main
import capa.features.extractors.strings  # used by read_string() below

logger = logging.getLogger(__name__)


def is_vertex_type(vertex: BinExport2.CallGraph.Vertex, type_: BinExport2.CallGraph.Vertex.Type.ValueType) -> bool:
    return vertex.HasField("type") and vertex.type == type_


def is_vertex_thunk(vertex: BinExport2.CallGraph.Vertex) -> bool:
    return is_vertex_type(vertex, BinExport2.CallGraph.Vertex.Type.THUNK)


THUNK_CHAIN_DEPTH_DELTA = 5


def compute_thunks(be2: BinExport2, idx: BinExport2Index) -> dict[int, int]:
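    """
    Map each thunk's address to the address of the function it ultimately targets,
    following chains of thunks up to THUNK_CHAIN_DEPTH_DELTA jumps deep.
    """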
    # from thunk address to target function address
    thunks: dict[int, int] = {}

    for addr, vertex_idx in idx.vertex_index_by_address.items():
        vertex: BinExport2.CallGraph.Vertex = be2.call_graph.vertex[vertex_idx]
        if not is_vertex_thunk(vertex):
            continue

        curr_vertex_idx: int = vertex_idx
        for _ in range(THUNK_CHAIN_DEPTH_DELTA):
            thunk_callees: list[int] = idx.callees_by_vertex_index[curr_vertex_idx]
            # if this doesn't hold, then this doesn't seem to be a thunk,
            # because the number of callees is either:
            #   0, and the thunk doesn't point to anything, such as `jmp eax`, or
            #   >1, and the thunk may end up at many functions.

            if not thunk_callees:
                # maybe we have an indirect jump, like `jmp eax`,
                # that we can't actually resolve here.
                break

            if len(thunk_callees) != 1:
                for thunk_callee in thunk_callees:
                    logger.warning("%s", hex(be2.call_graph.vertex[thunk_callee].address))
            assert len(thunk_callees) == 1, f"thunk @ {hex(addr)} failed"

            thunked_vertex_idx: int = thunk_callees[0]
            thunked_vertex: BinExport2.CallGraph.Vertex = be2.call_graph.vertex[thunked_vertex_idx]

            if not is_vertex_thunk(thunked_vertex):
                assert thunked_vertex.HasField("address")

                thunks[addr] = thunked_vertex.address
                break

            curr_vertex_idx = thunked_vertex_idx

    return thunks


def read_string(address_space: AddressSpace, address: int) -> Optional[str]:
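    """
    Read up to 0x100 bytes at the given address and return the ASCII or
    UTF-16 string found at its very start, or None if there is none.
    """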
    try:
        # if at end of segment then there might be an overrun here.
        buf: bytes = address_space.read_memory(address, 0x100)

    except ReadMemoryError:
        logger.debug("failed to read memory: 0x%x", address)
        return None

    # note: we *always* break after the first iteration
    for s in capa.features.extractors.strings.extract_ascii_strings(buf):
        if s.offset != 0:
            break

        return s.s

    # note: we *always* break after the first iteration
    for s in capa.features.extractors.strings.extract_unicode_strings(buf):
        if s.offset != 0:
            break

        return s.s

    return None


@dataclass
class AssemblageRow:
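    """A single row of the `assemblage` view, which joins the binaries, functions, and rvas tables."""
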
    # from table: binaries
    binary_id: int
    file_name: str
    platform: str
    build_mode: str
    toolset_version: str
    github_url: str
    optimization: str
    repo_last_update: int
    size: int
    path: str
    license: str
    binary_hash: str
    repo_commit_hash: str
    # from table: functions
    function_id: int
    function_name: str
    function_hash: str
    top_comments: str
    source_codes: str
    prototype: str
    _source_file: str
    # from table: rvas
    rva_id: int
    start_rva: int
    end_rva: int

    @property
    def source_file(self):
        # clean up some extra metadata provided by Assemblage,
        # like "foo.cpp (MD5: ...)" -> "foo.cpp"
        return self._source_file.partition(" (MD5: ")[0].partition(" (0x3: ")[0]


class Assemblage:
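    """
    Accessor for an Assemblage SQLite database and its samples directory.

    On construction, this creates helper indexes and the `assemblage` view
    (a join of binaries, functions, and rvas) used by the query methods below.
    """
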
    conn: sqlite3.Connection
    samples: Path

    def __init__(self, db: Path, samples: Path):
        super().__init__()

        self.db = db
        self.samples = samples

        self.conn = sqlite3.connect(self.db)
        with self.conn:
            self.conn.executescript(
                """
                PRAGMA journal_mode = WAL;
                PRAGMA synchronous = NORMAL;
                PRAGMA busy_timeout = 5000;
                PRAGMA cache_size = -20000; -- 20MB
                PRAGMA foreign_keys = true;
                PRAGMA temp_store = memory;

                BEGIN IMMEDIATE TRANSACTION;
                CREATE INDEX IF NOT EXISTS idx__functions__binary_id ON functions (binary_id);
                CREATE INDEX IF NOT EXISTS idx__rvas__function_id ON rvas (function_id);

                CREATE VIEW IF NOT EXISTS assemblage AS
                    SELECT
                        binaries.id AS binary_id,
                        binaries.file_name AS file_name,
                        binaries.platform AS platform,
                        binaries.build_mode AS build_mode,
                        binaries.toolset_version AS toolset_version,
                        binaries.github_url AS github_url,
                        binaries.optimization AS optimization,
                        binaries.repo_last_update AS repo_last_update,
                        binaries.size AS size,
                        binaries.path AS path,
                        binaries.license AS license,
                        binaries.hash AS hash,
                        binaries.repo_commit_hash AS repo_commit_hash,

                        functions.id AS function_id,
                        functions.name AS function_name,
                        functions.hash AS function_hash,
                        functions.top_comments AS top_comments,
                        functions.source_codes AS source_codes,
                        functions.prototype AS prototype,
                        functions.source_file AS source_file,

                        rvas.id AS rva_id,
                        rvas.start AS start_rva,
                        rvas.end AS end_rva
                    FROM binaries
                    JOIN functions ON binaries.id = functions.binary_id
                    JOIN rvas ON functions.id = rvas.function_id;
                """
            )

    def get_row_by_binary_id(self, binary_id: int) -> AssemblageRow:
        with self.conn:
            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id,))
            return AssemblageRow(*cur.fetchone())

    def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
        with self.conn:
            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id,))
            row = cur.fetchone()
            while row:
                yield AssemblageRow(*row)
                row = cur.fetchone()

    def get_path_by_binary_id(self, binary_id: int) -> Path:
        with self.conn:
            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id,))
            return self.samples / cur.fetchone()[0]

    def get_pe_by_binary_id(self, binary_id: int) -> pefile.PE:
        path = self.get_path_by_binary_id(binary_id)
        return pefile.PE(data=path.read_bytes(), fast_load=True)

    def get_binary_ids(self) -> Iterator[int]:
        with self.conn:
            cur = self.conn.execute("SELECT DISTINCT binary_id FROM assemblage ORDER BY binary_id ASC;")
            row = cur.fetchone()
            while row:
                yield row[0]
                row = cur.fetchone()


def generate_main(args: argparse.Namespace) -> int:
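    """
    Build the function/data graph for the binary identified by `args.binary_id`
    and write it to stdout as GEXF.
    """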
    if not args.assemblage_database.is_file():
        raise ValueError("database doesn't exist")

    db = Assemblage(args.assemblage_database, args.assemblage_directory)

    pe = db.get_pe_by_binary_id(args.binary_id)
    base_address: int = pe.OPTIONAL_HEADER.ImageBase

    functions_by_address = {
        base_address + function.start_rva: function for function in db.get_rows_by_binary_id(args.binary_id)
    }

    hash = db.get_row_by_binary_id(args.binary_id).binary_hash

    def make_node_id(address: int) -> str:
        return f"{hash}:{address:x}"

    pe_path = db.get_path_by_binary_id(args.binary_id)
    be2: BinExport2 = lancelot.get_binexport2_from_bytes(
        pe_path.read_bytes(), function_hints=list(functions_by_address.keys())
    )

    idx = lancelot.be2utils.BinExport2Index(be2)
    address_space = lancelot.be2utils.AddressSpace.from_pe(pe, base_address)
    thunks = compute_thunks(be2, idx)

    g = nx.MultiDiGraph()

    # ensure all functions from ground truth have an entry
    for address, function in functions_by_address.items():
        g.add_node(
            make_node_id(address),
            address=address,
            type="function",
        )

    for flow_graph in be2.flow_graph:
        datas: set[int] = set()
        callees: set[int] = set()

        entry_basic_block_index: int = flow_graph.entry_basic_block_index
        flow_graph_address: int = idx.get_basic_block_address(entry_basic_block_index)

        for basic_block_index in flow_graph.basic_block_index:
            basic_block: BinExport2.BasicBlock = be2.basic_block[basic_block_index]

            for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
                for addr in instruction.call_target:
                    addr = thunks.get(addr, addr)

                    if addr not in idx.vertex_index_by_address:
                        # disassembler did not define function at address
                        logger.debug("0x%x is not a vertex", addr)
                        continue

                    vertex_idx: int = idx.vertex_index_by_address[addr]
                    vertex: BinExport2.CallGraph.Vertex = be2.call_graph.vertex[vertex_idx]

                    callees.add(vertex.address)

                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(
                    instruction_index, []
                ):
                    data_reference: BinExport2.DataReference = be2.data_reference[data_reference_index]
                    data_reference_address: int = data_reference.address

                    if data_reference_address in idx.insn_address_by_index:
                        # appears to be code
                        continue

                    datas.add(data_reference_address)

        vertex_index = idx.vertex_index_by_address[flow_graph_address]
        name = idx.get_function_name_by_vertex(vertex_index)

        g.add_node(
            make_node_id(flow_graph_address),
            address=flow_graph_address,
            type="function",
        )
        if datas or callees:
            logger.info("%s @ 0x%X:", name, flow_graph_address)

            for data_address in sorted(datas):
                logger.info(" - 0x%X", data_address)
                g.add_node(
                    make_node_id(data_address),
                    address=data_address,
                    type="data",
                )
                g.add_edge(
                    make_node_id(flow_graph_address),
                    make_node_id(data_address),
                    key="reference",
                )

            for callee in sorted(callees):
                logger.info(" - %s", idx.get_function_name_by_address(callee))

                g.add_node(
                    make_node_id(callee),
                    address=callee,
                    type="function",
                )
                g.add_edge(
                    make_node_id(flow_graph_address),
                    make_node_id(callee),
                    key="call",
                )

        else:
            logger.info("%s @ 0x%X: (none)", name, flow_graph_address)

    # set ground truth node attributes from source data
    for node, attrs in g.nodes(data=True):
        if attrs["type"] != "function":
            continue

        if f := functions_by_address.get(attrs["address"]):
            attrs["name"] = f.function_name
            # the originating source file (cleaned up via AssemblageRow.source_file),
            # which is the ground truth behind the is_source_file_boundary edges below.
            attrs["file"] = f.source_file

    for section in pe.sections:
        # Within each section, emit a neighbor edge for each pair of neighbors.
        # Neighbors only link nodes of the same type, because assemblage doesn't
        # have ground truth for data items, so we don't quite know where to split.
        # Consider this situation:
        #
        #   moduleA::func1
        #   --- cut ---
        #   moduleB::func1
        #
        # that one is ok, but this is hard:
        #
        #   moduleA::func1
        #   --- cut??? ---
        #   dataZ
        #   --- or cut here??? ---
        #   moduleB::func1
        #
        # Does the cut go before or after dataZ?
        # So, we only have neighbor graphs within functions, and within datas.
        # For datas, we don't allow interspersed functions.

        section_nodes = sorted(
            [
                (node, attrs)
                for node, attrs in g.nodes(data=True)
                if (section.VirtualAddress + base_address)
                <= attrs["address"]
                < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
            ],
            key=lambda p: p[1]["address"],
        )

        # add neighbor edges between data items.
        # the data items must not be separated by any functions.
        for i in range(1, len(section_nodes)):
            a, a_attrs = section_nodes[i - 1]
            b, b_attrs = section_nodes[i]

            if a_attrs["type"] != "data":
                continue

            if b_attrs["type"] != "data":
                continue

            g.add_edge(a, b, key="neighbor")
            g.add_edge(b, a, key="neighbor")

        section_functions = [
            (node, attrs)
            for node, attrs in section_nodes
            if attrs["type"] == "function"
            # we only have ground truth for the known functions,
            # so only consider those in the function neighbor graph.
            and attrs["address"] in functions_by_address
        ]

        # add neighbor edges between functions.
        # we drop the potentially interspersed data items before computing these edges.
        for i in range(1, len(section_functions)):
            a, a_attrs = section_functions[i - 1]
            b, b_attrs = section_functions[i]
            # the neighbor edge crosses a source file boundary when the two
            # functions come from different source files.
            is_boundary = a_attrs["file"] != b_attrs["file"]

            # edge attribute: is_source_file_boundary
            g.add_edge(a, b, key="neighbor", is_source_file_boundary=is_boundary)
            g.add_edge(b, a, key="neighbor", is_source_file_boundary=is_boundary)

    # rename unknown functions like: sub_401000
    for n, attrs in g.nodes(data=True):
        if attrs["type"] != "function":
            continue

        if "name" in attrs:
            continue

        attrs["name"] = f"sub_{attrs['address']:x}"

    # assign human-readable repr to all nodes
    # assign is_import=bool to functions
    # assign is_string=bool to datas
    for n, attrs in g.nodes(data=True):
        match attrs["type"]:
            case "function":
                attrs["repr"] = attrs["name"]
                attrs["is_import"] = "!" in attrs["name"]
            case "data":
                if string := read_string(address_space, attrs["address"]):
                    attrs["repr"] = json.dumps(string)
                    attrs["is_string"] = True
                else:
                    attrs["repr"] = f"data_{attrs['address']:x}"
                    attrs["is_string"] = False

    for line in nx.generate_gexf(g):
        print(line)

    # db.conn.close()
    return 0


def _worker(args):
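    """
    Worker-pool entry point: run this script's `generate` subcommand for a single
    binary in a subprocess and write the captured GEXF output to `graph_file`.
    """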
    assemblage_database: Path
    assemblage_directory: Path
    graph_file: Path
    binary_id: int

    (assemblage_database, assemblage_directory, graph_file, binary_id) = args
    if graph_file.is_file():
        return

    logger.info("processing: %d", binary_id)
    process = subprocess.run(
        ["python", __file__, "--debug", "generate", assemblage_database, assemblage_directory, str(binary_id)],
        capture_output=True,
        encoding="utf-8",
    )
    if process.returncode != 0:
        logger.warning("failed: %d", binary_id)
        logger.debug("%s", process.stderr)
        return

    graph_file.parent.mkdir(exist_ok=True)
    graph = process.stdout
    graph_file.write_text(graph)


def generate_all_main(args: argparse.Namespace) -> int:
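    """Generate graphs for every binary in the Assemblage database using a pool of worker processes."""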
    if not args.assemblage_database.is_file():
        raise ValueError("database doesn't exist")

    db = Assemblage(args.assemblage_database, args.assemblage_directory)

    binary_ids = list(db.get_binary_ids())

    with Pool(args.num_workers) as p:
        _ = list(
            p.imap_unordered(
                _worker,
                (
                    (
                        args.assemblage_database,
                        args.assemblage_directory,
                        args.output_directory / str(binary_id) / "graph.gexf",
                        binary_id,
                    )
                    for binary_id in binary_ids
                ),
            )
        )

    return 0


def cluster_main(args: argparse.Namespace) -> int:
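    """Load a GEXF graph and print its Louvain communities."""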
    if not args.graph.is_file():
        raise ValueError("graph file doesn't exist")

    g = nx.read_gexf(args.graph)

    communities = nx.algorithms.community.louvain_communities(g)
    for i, community in enumerate(communities):
        print(f"[{i}]:")
        for node in community:
            # node ids look like "<binary hash>:<hex address>";
            # show the ground-truth source file when we have it.
            attrs = g.nodes[node]
            if "file" in attrs:
                print(f" - {node}: {attrs['file']}")
            else:
                print(f" - {node}")

    return 0


# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch-geometric pandas numpy
# import torch  # do this on-demand below, because it's slow
# from torch_geometric.data import HeteroData


@dataclass
class NodeType:
    type: str
    attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]


@dataclass
class EdgeType:
    key: str
    source_type: NodeType
    destination_type: NodeType
    attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float]


NODE_TYPES = {
    node.type: node
    for node in [
        NodeType(
            type="function",
            attributes={
                "is_import": False,
                # unused:
                # - repr: str
                # - address: int
                # - name: str
                # - file: str
            },
        ),
        NodeType(
            type="data",
            attributes={
                "is_string": False,
                # unused:
                # - repr: str
                # - address: int
            },
        ),
    ]
}

FUNCTION_NODE = NODE_TYPES["function"]
DATA_NODE = NODE_TYPES["data"]

EDGE_TYPES = {
    (edge.source_type.type, edge.key, edge.destination_type.type): edge
    for edge in [
        EdgeType(
            key="call",
            source_type=FUNCTION_NODE,
            destination_type=FUNCTION_NODE,
            attributes={},
        ),
        EdgeType(
            key="reference",
            source_type=FUNCTION_NODE,
            destination_type=DATA_NODE,
            attributes={},
        ),
        EdgeType(
            key="neighbor",
            source_type=FUNCTION_NODE,
            destination_type=FUNCTION_NODE,
            attributes={
                # this is the attribute to predict
                "is_source_file_boundary": False,
            },
        ),
        EdgeType(
            key="neighbor",
            source_type=DATA_NODE,
            destination_type=DATA_NODE,
            attributes={
                # this is the attribute to predict
                "is_source_file_boundary": False,
            },
        ),
    ]
}


@dataclass
class LoadedGraph:
    data: "HeteroData"
    # map from node id (str) to node index (int), and node index (int) to node id (str).
    mapping: dict[str | int, int | str]


def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
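    """
    Convert the networkx graph into a PyTorch Geometric HeteroData object,
    using the node and edge schemas described by NODE_TYPES and EDGE_TYPES,
    and return it along with the node id <-> node index mapping.
    """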
    import torch
    from torch_geometric.data import HeteroData

    # Our networkx graph identifies nodes by str ("sha256:address").
    # Torch identifies nodes by index, from 0 to #nodes.
    # Map one to another.
    node_indexes_by_node: dict[str, int] = {}
    nodes_by_node_index: dict[int, str] = {}
    # Because the types are different (str and int),
    # here's a single mapping where the type of the key implies
    # the sort of lookup you're doing (by index (int) or by node id (str)).
    node_mapping: dict[str | int, int | str] = {}
    for i, node in enumerate(sorted(g.nodes)):
        node_indexes_by_node[node] = i
        nodes_by_node_index[i] = node
        node_mapping[node] = i
        node_mapping[i] = node

    data = HeteroData()

    for node_type in NODE_TYPES.values():
        logger.debug("loading nodes: %s", node_type.type)

        node_indexes: list[int] = []
        attr_values: dict[str, list] = {attribute: [] for attribute in node_type.attributes.keys()}

        for node, attrs in g.nodes(data=True):
            if attrs["type"] != node_type.type:
                continue

            node_index = node_indexes_by_node[node]
            node_indexes.append(node_index)

            for attribute, default_value in node_type.attributes.items():
                value = attrs.get(attribute, default_value)
                attr_values[attribute].append(value)

        data[node_type.type].node_id = torch.tensor(node_indexes)
        if attr_values:
            # attribute order is implicit in the NODE_TYPES data model above.
            data[node_type.type].x = torch.stack(
                [torch.tensor(values) for values in attr_values.values()], dim=-1
            ).float()

    for edge_type in EDGE_TYPES.values():
        logger.debug(
            "loading edges: %s > %s > %s",
            edge_type.source_type.type,
            edge_type.key,
            edge_type.destination_type.type,
        )

        source_indexes: list[int] = []
        destination_indexes: list[int] = []
        attr_values = {attribute: [] for attribute in edge_type.attributes.keys()}

        for source, destination, key, attrs in g.edges(data=True, keys=True):
            if key != edge_type.key:
                continue
            if g.nodes[source]["type"] != edge_type.source_type.type:
                continue
            if g.nodes[destination]["type"] != edge_type.destination_type.type:
                continue

            source_index = node_indexes_by_node[source]
            destination_index = node_indexes_by_node[destination]

            source_indexes.append(source_index)
            destination_indexes.append(destination_index)

            for attribute, default_value in edge_type.attributes.items():
                value = attrs.get(attribute, default_value)
                attr_values[attribute].append(value)

        data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack(
            [
                torch.tensor(source_indexes),
                torch.tensor(destination_indexes),
            ]
        )
        if attr_values:
            # attribute order is implicit in the EDGE_TYPES data model above.
            data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack(
                [torch.tensor(values) for values in attr_values.values()], dim=-1
            ).float()

    return LoadedGraph(
        data,
        node_mapping,
    )


def train_main(args: argparse.Namespace) -> int:
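    """Load a GEXF graph into a PyTorch Geometric HeteroData object and print it."""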
    if not args.graph.is_file():
        raise ValueError("graph file doesn't exist")

    logger.debug("loading torch")
    import torch

    g = nx.read_gexf(args.graph)

    lg = load_graph(g)

    print(lg.data)

    return 0


def main(argv=None) -> int:
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description="Identify object boundaries in compiled programs")
    capa.main.install_common_args(parser, wanted={})
    subparsers = parser.add_subparsers(title="subcommands", required=True)

    generate_parser = subparsers.add_parser("generate", help="generate graph for a sample")
    generate_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
    generate_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
    generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
    generate_parser.set_defaults(func=generate_main)

    num_cores = os.cpu_count() or 1
    default_workers = max(1, num_cores - 2)
    generate_all_parser = subparsers.add_parser("generate_all", help="generate graphs for all samples")
    generate_all_parser.add_argument("assemblage_database", type=Path, help="path to Assemblage database")
    generate_all_parser.add_argument("assemblage_directory", type=Path, help="path to Assemblage samples directory")
    generate_all_parser.add_argument("output_directory", type=Path, help="path to output directory")
    generate_all_parser.add_argument(
        "--num_workers", type=int, default=default_workers, help="number of workers to use"
    )
    generate_all_parser.set_defaults(func=generate_all_main)

    cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
    cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
    cluster_parser.set_defaults(func=cluster_main)

    train_parser = subparsers.add_parser("train", help="train using an existing graph")
    train_parser.add_argument("graph", type=Path, help="path to a graph file")
    train_parser.set_defaults(func=train_main)

    args = parser.parse_args(args=argv)

    try:
        capa.main.handle_common_args(args)
    except capa.main.ShouldExitError as e:
        return e.status_code

    logging.getLogger("goblin.pe").setLevel(logging.WARNING)

    return args.func(args)


if __name__ == "__main__":
    sys.exit(main())