From f296e7d423f9c4eb30a65c734935a458456c1b03 Mon Sep 17 00:00:00 2001
From: Willi Ballenthin
Date: Wed, 6 Nov 2024 12:18:41 +0000
Subject: [PATCH] lints

---
 capa/analysis/libraries.py |  4 ++-
 scripts/codecut.py         | 65 +++++++++++++++++++-------------------
 2 files changed, 36 insertions(+), 33 deletions(-)

diff --git a/capa/analysis/libraries.py b/capa/analysis/libraries.py
index d2950056..6dc2d867 100644
--- a/capa/analysis/libraries.py
+++ b/capa/analysis/libraries.py
@@ -177,7 +177,9 @@ def main(argv=None):
     for va in idautils.Functions():
         name = idaapi.get_func_name(va)
 
-        if name not in {"WinMain", }:
+        if name not in {
+            "WinMain",
+        }:
             continue
 
         function_classifications.append(
diff --git a/scripts/codecut.py b/scripts/codecut.py
index 94dfdf2c..2b12fe38 100644
--- a/scripts/codecut.py
+++ b/scripts/codecut.py
@@ -3,17 +3,15 @@ import json
 import logging
 import sqlite3
 import argparse
-from typing import Optional
+from typing import Iterator, Optional
 from pathlib import Path
 from dataclasses import dataclass
 
-import rich
-import rich.table
 import pefile
 import lancelot
-import lancelot.be2utils
 import networkx as nx
-from lancelot.be2utils import BinExport2Index,ReadMemoryError, AddressSpace
+import lancelot.be2utils
+from lancelot.be2utils import AddressSpace, BinExport2Index, ReadMemoryError
 from lancelot.be2utils.binexport2_pb2 import BinExport2
 
 import capa.main
@@ -21,7 +19,6 @@ import capa.main
 logger = logging.getLogger(__name__)
 
 
-
 def is_vertex_type(vertex: BinExport2.CallGraph.Vertex, type_: BinExport2.CallGraph.Vertex.Type.ValueType) -> bool:
     return vertex.HasField("type") and vertex.type == type_
 
@@ -144,7 +141,8 @@ class Assemblage:
         self.conn = sqlite3.connect(self.db)
 
         with self.conn:
-            self.conn.executescript("""
+            self.conn.executescript(
+                """
                 PRAGMA journal_mode = WAL;
                 PRAGMA synchronous = NORMAL;
                 PRAGMA busy_timeout = 5000;
@@ -156,8 +154,8 @@ class Assemblage:
                 CREATE INDEX IF NOT EXISTS idx__functions__binary_id ON functions (binary_id);
                 CREATE INDEX IF NOT EXISTS idx__rvas__function_id ON rvas (function_id);
 
-                CREATE VIEW IF NOT EXISTS assemblage AS 
-                SELECT 
+                CREATE VIEW IF NOT EXISTS assemblage AS
+                SELECT
                     binaries.id AS binary_id,
                     binaries.file_name AS file_name,
                     binaries.platform AS platform,
@@ -183,19 +181,20 @@ class Assemblage:
                     rvas.id AS rva_id,
                     rvas.start AS start_rva,
                     rvas.end AS end_rva
-                FROM binaries 
+                FROM binaries
                 JOIN functions ON binaries.id = functions.binary_id
                 JOIN rvas ON functions.id = rvas.function_id;
-        """)
+        """
+        )
 
     def get_row_by_binary_id(self, binary_id: int) -> AssemblageRow:
         with self.conn:
-            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id, ))
+            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ? LIMIT 1;", (binary_id,))
             return AssemblageRow(*cur.fetchone())
 
-    def get_rows_by_binary_id(self, binary_id: int) -> AssemblageRow:
+    def get_rows_by_binary_id(self, binary_id: int) -> Iterator[AssemblageRow]:
         with self.conn:
-            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id, ))
+            cur = self.conn.execute("SELECT * FROM assemblage WHERE binary_id = ?;", (binary_id,))
             row = cur.fetchone()
             while row:
                 yield AssemblageRow(*row)
@@ -203,14 +202,13 @@ class Assemblage:
 
     def get_path_by_binary_id(self, binary_id: int) -> Path:
         with self.conn:
-            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id, ))
+            cur = self.conn.execute("""SELECT path FROM assemblage WHERE binary_id = ? LIMIT 1""", (binary_id,))
             return self.samples / cur.fetchone()[0]
 
     def get_pe_by_binary_id(self, binary_id: int) -> pefile.PE:
         path = self.get_path_by_binary_id(binary_id)
         return pefile.PE(data=path.read_bytes(), fast_load=True)
 
 
-
 def generate_main(args: argparse.Namespace) -> int:
     if not args.assemblage_database.is_file():
@@ -240,11 +238,7 @@ def generate_main(args: argparse.Namespace) -> int:
     pe_path = db.get_path_by_binary_id(args.binary_id)
 
     be2: BinExport2 = lancelot.get_binexport2_from_bytes(
-        pe_path.read_bytes(),
-        function_hints=[
-            base_address + function.start_rva
-            for function in functions
-        ]
+        pe_path.read_bytes(), function_hints=[base_address + function.start_rva for function in functions]
     )
 
     idx = lancelot.be2utils.BinExport2Index(be2)
@@ -253,7 +247,7 @@ def generate_main(args: argparse.Namespace) -> int:
 
     g = nx.MultiDiGraph()
 
-    for flow_graph_index, flow_graph in enumerate(be2.flow_graph):
+    for flow_graph in be2.flow_graph:
         datas: set[int] = set()
         callees: set[str] = set()
 
@@ -263,7 +257,7 @@ def generate_main(args: argparse.Namespace) -> int:
         for basic_block_index in flow_graph.basic_block_index:
             basic_block: BinExport2.BasicBlock = be2.basic_block[basic_block_index]
 
-            for instruction_index, instruction, instruction_address in idx.basic_block_instructions(basic_block):
+            for instruction_index, instruction, _ in idx.basic_block_instructions(basic_block):
                 for addr in instruction.call_target:
                     addr = thunks.get(addr, addr)
 
@@ -277,7 +271,9 @@ def generate_main(args: argparse.Namespace) -> int:
 
                     callees.add(vertex.address)
 
-                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(instruction_index, []):
+                for data_reference_index in idx.data_reference_index_by_source_instruction_index.get(
+                    instruction_index, []
+                ):
                     data_reference: BinExport2.DataReference = be2.data_reference[data_reference_index]
                     data_reference_address: int = data_reference.address
 
@@ -336,12 +332,15 @@ def generate_main(args: argparse.Namespace) -> int:
 
         # within each section, emit a neighbor edge for each pair of neighbors.
         section_nodes = [
-            node for node, attrs in g.nodes(data=True)
-            if (section.VirtualAddress + base_address) <= attrs["address"] < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
+            node
+            for node, attrs in g.nodes(data=True)
+            if (section.VirtualAddress + base_address)
+            <= attrs["address"]
+            < (base_address + section.VirtualAddress + section.Misc_VirtualSize)
         ]
 
         for i in range(1, len(section_nodes)):
-            a = section_nodes[i-1]
+            a = section_nodes[i - 1]
             b = section_nodes[i]
 
             g.add_edge(
@@ -353,8 +352,8 @@ def generate_main(args: argparse.Namespace) -> int:
             )
 
     for function in functions:
-        g.nodes[base_address+function.start_rva]["name"] = function.name
-        g.nodes[base_address+function.start_rva]["file"] = function.file
+        g.nodes[base_address + function.start_rva]["name"] = function.name
+        g.nodes[base_address + function.start_rva]["file"] = function.file
 
     # rename unknown functions like: sub_401000
     for n, attrs in g.nodes(data=True):
@@ -373,7 +372,7 @@ def generate_main(args: argparse.Namespace) -> int:
                 attrs["repr"] = attrs["name"]
                 attrs["is_import"] = "!" in attrs["name"]
             case "data":
-                if (string := read_string(address_space, n)):
+                if string := read_string(address_space, n):
                     attrs["repr"] = json.dumps(string)
                     attrs["is_string"] = True
                 else:
@@ -384,6 +383,7 @@ def generate_main(args: argparse.Namespace) -> int:
         print(line)
 
     # db.conn.close()
+    return 0
 
 
 def cluster_main(args: argparse.Namespace) -> int:
@@ -391,7 +391,7 @@ def cluster_main(args: argparse.Namespace) -> int:
     if not args.graph.is_file():
         raise ValueError("graph file doesn't exist")
 
     g = nx.read_gexf(args.graph)
-    
+
     communities = nx.algorithms.community.louvain_communities(g)
     for i, community in enumerate(communities):
         print(f"[{i}]:")
@@ -401,6 +401,8 @@ def cluster_main(args: argparse.Namespace) -> int:
             else:
                 print(f" - {hex(int(node, 0))}")
 
+    return 0
+
 
 def main(argv=None) -> int:
     if argv is None:
@@ -416,7 +418,6 @@ def main(argv=None) -> int:
     generate_parser.add_argument("binary_id", type=int, help="primary key of binary to inspect")
     generate_parser.set_defaults(func=generate_main)
 
-
    cluster_parser = subparsers.add_parser("cluster", help="cluster an existing graph")
    cluster_parser.add_argument("graph", type=Path, help="path to a graph file")
    cluster_parser.set_defaults(func=cluster_main)
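
Note for reviewers (not part of the patch): the `cluster` subcommand touched above
reduces to Louvain community detection over the GEXF graph that the `generate`
subcommand writes. A minimal standalone sketch of that core, using networkx's
bundled karate-club graph as a stand-in for a real codecut GEXF file (so it runs
without an assemblage database); the `seed` argument is only for reproducibility:

    import networkx as nx

    # small built-in graph; stands in for nx.read_gexf(args.graph)
    g = nx.karate_club_graph()

    # louvain_communities returns a list of sets of node ids,
    # one set per detected community
    communities = nx.algorithms.community.louvain_communities(g, seed=1)

    for i, community in enumerate(communities):
        print(f"[{i}]: {sorted(community)}")

louvain_communities requires networkx >= 2.8.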