From 891fa8aaa3f42653b4821b20790bbb8ee6681638 Mon Sep 17 00:00:00 2001 From: Willi Ballenthin Date: Thu, 14 Nov 2024 10:32:07 +0000 Subject: [PATCH] codecut: torch loader --- scripts/codecut.py | 113 ++++++++++++++++++--------------------------- 1 file changed, 46 insertions(+), 67 deletions(-) diff --git a/scripts/codecut.py b/scripts/codecut.py index ae5077e7..59fd2402 100644 --- a/scripts/codecut.py +++ b/scripts/codecut.py @@ -5,7 +5,7 @@ import logging import sqlite3 import argparse import subprocess -from typing import Iterator, Optional +from typing import Iterator, Optional, Literal from pathlib import Path from dataclasses import dataclass from multiprocessing import Pool @@ -521,20 +521,16 @@ def cluster_main(args: argparse.Namespace) -> int: -# uv pip install torch --index-url https://download.pytorch.org/whl/cpu +# uv pip install torch --index-url https://download.pytorch.org/whl/cpu # uv pip install torch-geometric pandas numpy -import pandas as pd -import numpy as np -import torch_geometric.utils.convert - -import torch -from torch_geometric.data import HeteroData +# import torch # do this on-demand below, because its slow +# from torch_geometric.data import HeteroData @dataclass class NodeType: type: str - attributes: Dict[str, int | bool] + attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float] @dataclass @@ -542,7 +538,7 @@ class EdgeType: key: str source_type: NodeType destination_type: NodeType - attributes: Dict[str, int | bool] + attributes: dict[str, Literal[False] | Literal[""] | Literal[0] | float] NODE_TYPES = { @@ -551,22 +547,22 @@ NODE_TYPES = { NodeType( type="function", attributes={ - "is_import": bool, + "is_import": False, # unused: # - repr: str # - address: int # - name: str # - file: str - } + }, ), NodeType( type="data", attributes={ - "is_string": bool, + "is_string": False, # unused: # - repr: str # - address: int - } + }, ), ] } @@ -613,12 +609,15 @@ EDGE_TYPES = { @dataclass class LoadedGraph: - data: HeteroGraph + data: "HeteroData" # map from node id (str) to node index (int), and node index (int) to node id (str). mapping: dict[str | int, int | str] def load_graph(g: nx.MultiDiGraph) -> LoadedGraph: + import torch + from torch_geometric.data import HeteroData + # Our networkx graph identifies nodes by str ("sha256:address"). # Torch identifies nodes by index, from 0 to #nodes. # Map one to another. @@ -627,8 +626,8 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph: # Because the types are different (str and int), # here's a single mapping where the type of the key implies # the sort of lookup you're doing (by index (int) or by node id (str)). - node_mapping = dict[str | int, int | str] = {} - for i, node in enumerate(sort(g.nodes)): + node_mapping: dict[str | int, int | str] = {} + for i, node in enumerate(sorted(g.nodes)): node_indexes_by_node[node] = i nodes_by_node_index[i] = node node_mapping[node] = i @@ -637,11 +636,10 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph: data = HeteroData() for node_type in NODE_TYPES.values(): + logger.debug("loading nodes: %s", node_type.type) + node_indexes: list[int] = [] - attr_values: dict[str, list] = { - attribute: [] - for attribute in node_type.attributes.keys() - } + attr_values: dict[str, list] = {attribute: [] for attribute in node_type.attributes.keys()} for node, attrs in g.nodes(data=True): if attrs["type"] != node_type.type: @@ -655,16 +653,19 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph: attr_values[attribute].append(value) data[node_type.type].node_id = torch.tensor(node_indexes) - # attribute order is implicit in the NODE_TYPES data model above. - data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float() + if attr_values: + # attribute order is implicit in the NODE_TYPES data model above. + data[node_type.type].x = torch.stack([torch.tensor(values) for values in attr_values.values()], dim=-1).float() for edge_type in EDGE_TYPES.values(): + logger.debug( + "loading edges: %s > %s > %s", + edge_type.source_type.type, edge_type.key, edge_type.destination_type.type + ) + source_indexes: list[int] = [] destination_indexes: list[int] = [] - attr_values: dict[str, list] = { - attribute: [] - for attribute in edge_type.attributes.keys() - } + attr_values: dict[str, list] = {attribute: [] for attribute in edge_type.attributes.keys()} for source, destination, key, attrs in g.edges(data=True, keys=True): if key != edge_type.key: @@ -684,58 +685,36 @@ def load_graph(g: nx.MultiDiGraph) -> LoadedGraph: value = attrs.get(attribute, default_value) attr_values[attribute].append(value) - data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack([ - torch.tensor(source_indexes), - torch.tensor(destination_indexes), - ]) - # attribute order is implicit in the EDGE_TYPES data model above. - data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack([ - torch.tensor(values) for values in attr_values.values() - ], dim=-1).float() + data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_index = torch.stack( + [ + torch.tensor(source_indexes), + torch.tensor(destination_indexes), + ] + ) + if attr_values: + # attribute order is implicit in the EDGE_TYPES data model above. + data[edge_type.source_type.type, edge_type.key, edge_type.destination_type.type].edge_attr = torch.stack( + [torch.tensor(values) for values in attr_values.values()], dim=-1 + ).float() - return LoadedData( + return LoadedGraph( data, node_mapping, ) - + def train_main(args: argparse.Namespace) -> int: if not args.graph.is_file(): raise ValueError("graph file doesn't exist") + logger.debug("loading torch") + import torch + g = nx.read_gexf(args.graph) - # set node default attributes - for _, attrs in g.nodes(data=True): - if "type" not in attrs: - raise ValueError("node missing `type`") + lg = load_graph(g) - for key, value in { - "name": "", - "file": "", - "repr": "(unknown)", - "is_import": False, - "is_string": False, - }.items(): - if key not in attrs: - attrs[key] = value - - # set edge default attributes - for a, b, key, attrs in g.edges(data=True, keys=True): - if "key" not in attrs: - raise ValueError("edge missing `key`", a, b, key, attrs) - - for key, value in { - # TODO: this should only be on neighbor edges, for multi-graphs - "is_source_file_boundary": false, - }.items(): - if key not in attrs: - attrs[key] = value - - # TODO: this is only for Graph or DiGraph, not MultiGraph - graph = torch_geometric.utils.convert.from_networkx(g) - - print(graph) + print(lg.data) return 0