This commit is contained in:
Willi Ballenthin
2025-01-14 09:18:06 +00:00
parent 891fa8aaa3
commit d89083ab31

View File

@@ -310,6 +310,7 @@ def generate_main(args: argparse.Namespace) -> int:
for data_address in sorted(datas):
logger.info(" - 0x%X", data_address)
# TODO: check if this is already a function
g.add_node(
make_node_id(data_address),
address=data_address,
@@ -520,9 +521,8 @@ def cluster_main(args: argparse.Namespace) -> int:
return 0
# uv pip install torch --index-url https://download.pytorch.org/whl/cpu
# uv pip install torch-geometric pandas numpy
# uv pip install torch-geometric pandas numpy scikit-learn
# import torch # do this on-demand below, because its slow
# from torch_geometric.data import HeteroData
@@ -548,6 +548,8 @@ NODE_TYPES = {
type="function",
attributes={
"is_import": False,
"does_reference_string": False,
# "ground_truth": False,
# unused:
# - repr: str
# - address: int
@@ -585,22 +587,37 @@ EDGE_TYPES = {
destination_type=DATA_NODE,
attributes={},
),
EdgeType(
# When functions reference other functions as data,
# such as passing a function pointer as a callback.
#
# Example:
# __scrt_set_unhandled_exception_filter > reference > __scrt_unhandled_exception_filter
key="reference",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={},
),
EdgeType(
key="neighbor",
source_type=FUNCTION_NODE,
destination_type=FUNCTION_NODE,
attributes={
# this is the attribute to predict
"is_source_file_boundary": False,
# this is the attribute to predict (ultimately)
# "is_source_file_boundary": False,
"distance": 1,
},
),
EdgeType(
key="neighbor",
source_type=DATA_NODE,
destination_type=DATA_NODE,
# attributes={
# },
attributes={
# this is the attribute to predict
"is_source_file_boundary": False,
# this is the attribute to predict (ultimately)
# "is_source_file_boundary": False,
"distance": 1,
},
),
]
@@ -610,8 +627,10 @@ EDGE_TYPES = {
@dataclass
class LoadedGraph:
data: "HeteroData"
# map from node type to:
# map from node id (str) to node index (int), and node index (int) to node id (str).
mapping: dict[str | int, int | str]
mapping: dict[str, dict[str | int, int | str]]
def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
@@ -619,19 +638,26 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
from torch_geometric.data import HeteroData
# Our networkx graph identifies nodes by str ("sha256:address").
# Torch identifies nodes by index, from 0 to #nodes.
# Torch identifies nodes by index, from 0 to #nodes, for each type of node.
# Map one to another.
node_indexes_by_node: dict[str, int] = {}
nodes_by_node_index: dict[int, str] = {}
node_indexes_by_node: dict[str, dict[str, int]] = {n: {} for n in NODE_TYPES.keys()}
# Because the types are different (str and int),
# here's a single mapping where the type of the key implies
# the sort of lookup you're doing (by index (int) or by node id (str)).
node_mapping: dict[str | int, int | str] = {}
for i, node in enumerate(sorted(g.nodes)):
node_indexes_by_node[node] = i
nodes_by_node_index[i] = node
node_mapping[node] = i
node_mapping[i] = node
node_mapping: dict[str, dict[str | int, int | str]] = {n: {} for n in NODE_TYPES.keys()}
for node_type in NODE_TYPES.keys():
def is_this_node_type(node_attrs):
node, attrs = node_attrs
return attrs["type"] == node_type
ns = g.nodes(data=True)
ns = sorted(ns)
ns = filter(is_this_node_type, ns)
ns = map(lambda p: p[0], ns)
for i, node in enumerate(ns):
node_indexes_by_node[node_type][node] = i
node_mapping[node_type][node] = i
node_mapping[node_type][i] = node
data = HeteroData()
@@ -645,7 +671,7 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
if attrs["type"] != node_type.type:
continue
node_index = node_indexes_by_node[node]
node_index = node_indexes_by_node[node_type.type][node]
node_indexes.append(node_index)
for attribute, default_value in node_type.attributes.items():
@@ -675,8 +701,11 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph:
if g.nodes[destination]["type"] != edge_type.destination_type.type:
continue
source_index = node_indexes_by_node[source]
destination_index = node_indexes_by_node[destination]
# These are global node indexes
# but we need to provide the node type-local index.
# That is, functions have their own node indexes, 0 to N. data have their own node indexes, 0 to N.
source_index = node_indexes_by_node[g.nodes[source]["type"]][source]
destination_index = node_indexes_by_node[g.nodes[destination]["type"]][destination]
source_indexes.append(source_index)
destination_indexes.append(destination_index)
@@ -710,11 +739,184 @@ def train_main(args: argparse.Namespace) -> int:
logger.debug("loading torch")
import torch
import random
import numpy as np
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
logger.debug("reading graph from disk")
g = nx.read_gexf(args.graph)
lg = load_graph(g)
# Initial model: learn to find functions that reference a string.
#
# Once this works, then we can try a more complex model (edge features),
# and ultimately an edge classifier.
#
# Ground truth from existing patterns like:
#
# function > references > data (:is_string=True)
print(lg.data)
for a, b, key, attrs in g.edges(data=True, keys=True):
match (g.nodes[a]["type"], key, g.nodes[b]["type"]):
case ("function", "reference", "data"):
if g.nodes[b].get("is_string"):
g.nodes[a]["does_reference_string"] = True
logger.debug("%s > reference > %s (string)", g.nodes[a]["repr"], g.nodes[b]["repr"])
case ("function", "reference", "function"):
# The data model supports this.
# Like passing a function pointer as a callback
continue
case ("data", "reference", "data"):
# We don't support this.
continue
case ("data", "reference", "function"):
# We don't support this.
continue
case (_, "call", _):
continue
case (_, "neighbor", _):
continue
case _:
print(a, b, key, attrs, g.nodes[a], g.nodes[b])
raise ValueError("unexpected structure")
# map existing attributes to the ground_truth attribute
# for ease of updating the model/training.
for node, attrs in g.nodes(data=True):
if attrs["type"] != "function":
continue
attrs["ground_truth"] = attrs.get("does_reference_string", False)
logger.debug("loading graph into torch")
lg = load_graph(g)
data = lg.data
data['data'].y = torch.zeros(data['data'].num_nodes, dtype=torch.long)
data['function'].y = torch.zeros(data['function'].num_nodes, dtype=torch.long)
true_indices = []
for node, attrs in g.nodes(data=True):
if attrs.get("ground_truth"):
print("true: ", g.nodes[node]["repr"])
node_index = lg.mapping[attrs["type"]][node]
print("index", attrs["type"], node_index)
print(" ", node)
print(" ", lg.mapping[attrs["type"]][node_index])
true_indices.append(node_index)
# true_indices.append(data['function'].node_id[node_index].item())
# print("true index: ", node_index, data['function'].node_id[node_index].item())
data['function'].y[true_indices] = 1
print(data['function'].y)
# TODO
import torch_geometric.transforms as T
data = T.ToUndirected()(data)
# data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)
print(data)
from torch_geometric.nn import RGCNConv, to_hetero, SAGEConv, Linear
import torch.nn.functional as F
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE encoder with a linear classification head.

    This module is written for a homogeneous graph and is converted into a
    heterogeneous model later via ``to_hetero(model, data.metadata(), ...)``,
    so ``forward`` receives per-node-type feature dicts after conversion.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # NOTE(review): the (-1, -1) input size defers feature-dimension
        # inference to the first forward pass (PyG lazy initialization),
        # which is needed because each node type may have a different
        # feature width after to_hetero conversion.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # Two rounds of neighborhood aggregation (ReLU between them),
        # then project to class logits; no softmax here — the caller
        # uses CrossEntropyLoss, which expects raw logits.
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        x = self.lin(x)
        return x
model = GNN(hidden_channels=4, out_channels=2)
# metadata: tuple[list of node types, list of edge types (source, key, dest)]
model = to_hetero(model, data.metadata(), aggr='sum')
# model.print_readable()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
from sklearn.model_selection import train_test_split
train_nodes, test_nodes = train_test_split(
torch.arange(data['function'].num_nodes), test_size=0.2, random_state=42
)
train_mask = torch.zeros(data['function'].num_nodes, dtype=torch.bool)
# train_mask[train_nodes] = True
train_mask[:] = True
test_mask = torch.zeros(data['function'].num_nodes, dtype=torch.bool)
# test_mask[test_nodes] = True
test_mask[:] = True
data['function'].train_mask = train_mask
data['function'].test_mask = test_mask
logger.debug("training")
for epoch in range(999):
model.train()
optimizer.zero_grad()
# don't use edge attrs right now.
out = model(data.x_dict, data.edge_index_dict) # data.edge_attr_dict)
out_function = out['function']
y_function = data['function'].y
mask = data['function'].train_mask
# When classifying "function has string reference"
# there is a major class imbalance, because 95% of functions don't reference a string,
# so the model just learns to predict "no".
# Therefore, weight the classes so that a "yes" prediction is much more valuable.
class_counts = torch.bincount(data['function'].y[mask])
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * len(class_counts)
# CrossEntropyLoss(): the most common choice for node classification with mutually exclusive classes.
# BCEWithLogitsLoss(): multi-label node classification
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(out_function[mask], y_function[mask])
loss.backward()
optimizer.step()
logger.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
if loss <= 0.0001:
logger.info("no more loss")
break
logger.debug("evaluating")
model.eval()
with torch.no_grad():
out = model(data.x_dict, data.edge_index_dict) # TODO: edge attrs
mask = data['function'].test_mask
pred = torch.argmax(out['function'][mask], dim=1)
truth = data['function'].y[mask].int()
print("pred", pred[:32])
print("truth", truth[:32])
# print("index", data['function'].node_id[mask])
# print("83: ", g.nodes[lg.mapping['function'][83]]['repr'])
accuracy = (pred == truth).float().mean()
# pred = (out[data['function'].test_mask] > 0).int().squeeze()
# accuracy = (pred == data['function'].y[data['function'].test_mask]).float().mean()
print(f'Accuracy: {accuracy:.4f}')
return 0