mirror of
https://github.com/mandiant/capa.git
synced 2025-12-05 20:40:05 -08:00
wip
This commit is contained in:
@@ -310,6 +310,7 @@ def generate_main(args: argparse.Namespace) -> int:
|
||||
|
||||
for data_address in sorted(datas):
|
||||
logger.info(" - 0x%X", data_address)
|
||||
# TODO: check if this is already a function
|
||||
g.add_node(
|
||||
make_node_id(data_address),
|
||||
address=data_address,
|
||||
@@ -520,9 +521,8 @@ def cluster_main(args: argparse.Namespace) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
# uv pip install torch-geometric pandas numpy
|
||||
# uv pip install torch-geometric pandas numpy scikit-learn
|
||||
# import torch # do this on-demand below, because it's slow
|
||||
# from torch_geometric.data import HeteroData
|
||||
|
||||
@@ -548,6 +548,8 @@ NODE_TYPES = {
|
||||
type="function",
|
||||
attributes={
|
||||
"is_import": False,
|
||||
"does_reference_string": False,
|
||||
# "ground_truth": False,
|
||||
# unused:
|
||||
# - repr: str
|
||||
# - address: int
|
||||
@@ -585,22 +587,37 @@ EDGE_TYPES = {
|
||||
destination_type=DATA_NODE,
|
||||
attributes={},
|
||||
),
|
||||
EdgeType(
|
||||
# When functions reference other functions as data,
|
||||
# such as passing a function pointer as a callback.
|
||||
#
|
||||
# Example:
|
||||
# __scrt_set_unhandled_exception_filter > reference > __scrt_unhandled_exception_filter
|
||||
key="reference",
|
||||
source_type=FUNCTION_NODE,
|
||||
destination_type=FUNCTION_NODE,
|
||||
attributes={},
|
||||
),
|
||||
EdgeType(
|
||||
key="neighbor",
|
||||
source_type=FUNCTION_NODE,
|
||||
destination_type=FUNCTION_NODE,
|
||||
attributes={
|
||||
# this is the attribute to predict
|
||||
"is_source_file_boundary": False,
|
||||
# this is the attribute to predict (ultimately)
|
||||
# "is_source_file_boundary": False,
|
||||
"distance": 1,
|
||||
},
|
||||
),
|
||||
EdgeType(
|
||||
key="neighbor",
|
||||
source_type=DATA_NODE,
|
||||
destination_type=DATA_NODE,
|
||||
# attributes={
|
||||
# },
|
||||
attributes={
|
||||
# this is the attribute to predict
|
||||
"is_source_file_boundary": False,
|
||||
# this is the attribute to predict (ultimately)
|
||||
# "is_source_file_boundary": False,
|
||||
"distance": 1,
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -610,8 +627,10 @@ EDGE_TYPES = {
|
||||
@dataclass
|
||||
class LoadedGraph:
|
||||
data: "HeteroData"
|
||||
|
||||
# map from node type to:
|
||||
# map from node id (str) to node index (int), and node index (int) to node id (str).
|
||||
mapping: dict[str | int, int | str]
|
||||
mapping: dict[str, dict[str | int, int | str]]
|
||||
|
||||
|
||||
def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
|
||||
@@ -619,19 +638,26 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
|
||||
from torch_geometric.data import HeteroData
|
||||
|
||||
# Our networkx graph identifies nodes by str ("sha256:address").
|
||||
# Torch identifies nodes by index, from 0 to #nodes.
|
||||
# Torch identifies nodes by index, from 0 to #nodes, for each type of node.
|
||||
# Map one to another.
|
||||
node_indexes_by_node: dict[str, int] = {}
|
||||
nodes_by_node_index: dict[int, str] = {}
|
||||
node_indexes_by_node: dict[str, dict[str, int]] = {n: {} for n in NODE_TYPES.keys()}
|
||||
# Because the types are different (str and int),
|
||||
# here's a single mapping where the type of the key implies
|
||||
# the sort of lookup you're doing (by index (int) or by node id (str)).
|
||||
node_mapping: dict[str | int, int | str] = {}
|
||||
for i, node in enumerate(sorted(g.nodes)):
|
||||
node_indexes_by_node[node] = i
|
||||
nodes_by_node_index[i] = node
|
||||
node_mapping[node] = i
|
||||
node_mapping[i] = node
|
||||
node_mapping: dict[str, dict[str | int, int | str]] = {n: {} for n in NODE_TYPES.keys()}
|
||||
for node_type in NODE_TYPES.keys():
|
||||
def is_this_node_type(node_attrs):
|
||||
node, attrs = node_attrs
|
||||
return attrs["type"] == node_type
|
||||
|
||||
ns = g.nodes(data=True)
|
||||
ns = sorted(ns)
|
||||
ns = filter(is_this_node_type, ns)
|
||||
ns = map(lambda p: p[0], ns)
|
||||
for i, node in enumerate(ns):
|
||||
node_indexes_by_node[node_type][node] = i
|
||||
node_mapping[node_type][node] = i
|
||||
node_mapping[node_type][i] = node
|
||||
|
||||
data = HeteroData()
|
||||
|
||||
@@ -645,7 +671,7 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
|
||||
if attrs["type"] != node_type.type:
|
||||
continue
|
||||
|
||||
node_index = node_indexes_by_node[node]
|
||||
node_index = node_indexes_by_node[node_type.type][node]
|
||||
node_indexes.append(node_index)
|
||||
|
||||
for attribute, default_value in node_type.attributes.items():
|
||||
@@ -675,8 +701,11 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
|
||||
if g.nodes[destination]["type"] != edge_type.destination_type.type:
|
||||
continue
|
||||
|
||||
source_index = node_indexes_by_node[source]
|
||||
destination_index = node_indexes_by_node[destination]
|
||||
# These are global node indexes
|
||||
# but we need to provide the node type-local index.
|
||||
# That is, functions have their own node indexes, 0 to N. Data nodes have their own node indexes, 0 to N.
|
||||
source_index = node_indexes_by_node[g.nodes[source]["type"]][source]
|
||||
destination_index = node_indexes_by_node[g.nodes[destination]["type"]][destination]
|
||||
|
||||
source_indexes.append(source_index)
|
||||
destination_indexes.append(destination_index)
|
||||
@@ -710,11 +739,184 @@ def train_main(args: argparse.Namespace) -> int:
|
||||
logger.debug("loading torch")
|
||||
import torch
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
logger.debug("reading graph from disk")
|
||||
g = nx.read_gexf(args.graph)
|
||||
|
||||
lg = load_graph(g)
|
||||
# Initial model: learn to find functions that reference a string.
|
||||
#
|
||||
# Once this works, then we can try a more complex model (edge features),
|
||||
# and ultimately an edge classifier.
|
||||
#
|
||||
# Ground truth from existing patterns like:
|
||||
#
|
||||
# function > references > data (:is_string=True)
|
||||
|
||||
print(lg.data)
|
||||
for a, b, key, attrs in g.edges(data=True, keys=True):
|
||||
match (g.nodes[a]["type"], key, g.nodes[b]["type"]):
|
||||
case ("function", "reference", "data"):
|
||||
|
||||
if g.nodes[b].get("is_string"):
|
||||
g.nodes[a]["does_reference_string"] = True
|
||||
logger.debug("%s > reference > %s (string)", g.nodes[a]["repr"], g.nodes[b]["repr"])
|
||||
|
||||
case ("function", "reference", "function"):
|
||||
# The data model supports this.
|
||||
# Like passing a function pointer as a callback
|
||||
continue
|
||||
case ("data", "reference", "data"):
|
||||
# We don't support this.
|
||||
continue
|
||||
case ("data", "reference", "function"):
|
||||
# We don't support this.
|
||||
continue
|
||||
case (_, "call", _):
|
||||
continue
|
||||
case (_, "neighbor", _):
|
||||
continue
|
||||
case _:
|
||||
print(a, b, key, attrs, g.nodes[a], g.nodes[b])
|
||||
raise ValueError("unexpected structure")
|
||||
|
||||
# map existing attributes to the ground_truth attribute
|
||||
# for ease of updating the model/training.
|
||||
for node, attrs in g.nodes(data=True):
|
||||
if attrs["type"] != "function":
|
||||
continue
|
||||
|
||||
attrs["ground_truth"] = attrs.get("does_reference_string", False)
|
||||
|
||||
logger.debug("loading graph into torch")
|
||||
lg = load_graph(g)
|
||||
data = lg.data
|
||||
|
||||
data['data'].y = torch.zeros(data['data'].num_nodes, dtype=torch.long)
|
||||
data['function'].y = torch.zeros(data['function'].num_nodes, dtype=torch.long)
|
||||
true_indices = []
|
||||
|
||||
for node, attrs in g.nodes(data=True):
|
||||
if attrs.get("ground_truth"):
|
||||
print("true: ", g.nodes[node]["repr"])
|
||||
node_index = lg.mapping[attrs["type"]][node]
|
||||
print("index", attrs["type"], node_index)
|
||||
print(" ", node)
|
||||
print(" ", lg.mapping[attrs["type"]][node_index])
|
||||
|
||||
true_indices.append(node_index)
|
||||
# true_indices.append(data['function'].node_id[node_index].item())
|
||||
# print("true index: ", node_index, data['function'].node_id[node_index].item())
|
||||
|
||||
data['function'].y[true_indices] = 1
|
||||
print(data['function'].y)
|
||||
|
||||
# TODO
|
||||
import torch_geometric.transforms as T
|
||||
data = T.ToUndirected()(data)
|
||||
# data = T.AddSelfLoops()(data)
|
||||
data = T.NormalizeFeatures()(data)
|
||||
|
||||
print(data)
|
||||
|
||||
from torch_geometric.nn import RGCNConv, to_hetero, SAGEConv, Linear
|
||||
import torch.nn.functional as F
|
||||
|
||||
class GNN(torch.nn.Module):
|
||||
def __init__(self, hidden_channels, out_channels):
|
||||
super().__init__()
|
||||
self.conv1 = SAGEConv((-1, -1), hidden_channels)
|
||||
self.conv2 = SAGEConv((-1, -1), hidden_channels)
|
||||
self.lin = Linear(hidden_channels, out_channels)
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
x = self.conv1(x, edge_index).relu()
|
||||
x = self.conv2(x, edge_index)
|
||||
x = self.lin(x)
|
||||
return x
|
||||
|
||||
model = GNN(hidden_channels=4, out_channels=2)
|
||||
# metadata: tuple[list of node types, list of edge types (source, key, dest)]
|
||||
model = to_hetero(model, data.metadata(), aggr='sum')
|
||||
# model.print_readable()
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
train_nodes, test_nodes = train_test_split(
|
||||
torch.arange(data['function'].num_nodes), test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
train_mask = torch.zeros(data['function'].num_nodes, dtype=torch.bool)
|
||||
# train_mask[train_nodes] = True
|
||||
train_mask[:] = True
|
||||
|
||||
test_mask = torch.zeros(data['function'].num_nodes, dtype=torch.bool)
|
||||
# test_mask[test_nodes] = True
|
||||
test_mask[:] = True
|
||||
|
||||
data['function'].train_mask = train_mask
|
||||
data['function'].test_mask = test_mask
|
||||
|
||||
logger.debug("training")
|
||||
for epoch in range(999):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# don't use edge attrs right now.
|
||||
out = model(data.x_dict, data.edge_index_dict) # data.edge_attr_dict)
|
||||
|
||||
out_function = out['function']
|
||||
y_function = data['function'].y
|
||||
|
||||
mask = data['function'].train_mask
|
||||
|
||||
# When classifying "function has string reference"
|
||||
# there is a major class imbalance, because 95% of functions don't reference a string,
|
||||
# so the model just learns to predict "no".
|
||||
# Therefore, weight the classes so that a "yes" prediction is much more valuable.
|
||||
class_counts = torch.bincount(data['function'].y[mask])
|
||||
class_weights = 1.0 / class_counts.float()
|
||||
class_weights = class_weights / class_weights.sum() * len(class_counts)
|
||||
|
||||
# CrossEntropyLoss(): the most common choice for node classification with mutually exclusive classes.
|
||||
# BCEWithLogitsLoss(): multi-label node classification
|
||||
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
|
||||
|
||||
loss = criterion(out_function[mask], y_function[mask])
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
logger.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
|
||||
if loss <= 0.0001:
|
||||
logger.info("no more loss")
|
||||
break
|
||||
|
||||
logger.debug("evaluating")
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
out = model(data.x_dict, data.edge_index_dict) # TODO: edge attrs
|
||||
|
||||
mask = data['function'].test_mask
|
||||
pred = torch.argmax(out['function'][mask], dim=1)
|
||||
truth = data['function'].y[mask].int()
|
||||
|
||||
print("pred", pred[:32])
|
||||
print("truth", truth[:32])
|
||||
# print("index", data['function'].node_id[mask])
|
||||
# print("83: ", g.nodes[lg.mapping['function'][83]]['repr'])
|
||||
|
||||
accuracy = (pred == truth).float().mean()
|
||||
|
||||
# pred = (out[data['function'].test_mask] > 0).int().squeeze()
|
||||
# accuracy = (pred == data['function'].y[data['function'].test_mask]).float().mean()
|
||||
print(f'Accuracy: {accuracy:.4f}')
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user