introduce flake8-simplify

This commit is contained in:
Willi Ballenthin
2023-07-12 11:40:44 +02:00
parent 7526ff876f
commit 65e8300145
28 changed files with 113 additions and 130 deletions

11
.github/flake8.ini vendored
View File

@@ -13,7 +13,16 @@ extend-ignore =
# B010 Do not call setattr with a constant attribute value # B010 Do not call setattr with a constant attribute value
B010, B010,
# G200 Logging statement uses exception in arguments # G200 Logging statement uses exception in arguments
G200 G200,
# SIM102 Use a single if-statement instead of nested if-statements
# doesn't provide a space for commenting or logical separation of conditions
SIM102,
# SIM114 Use logical or and a single body
# makes logic trees too complex
SIM114,
# SIM117 Use 'with Foo, Bar:' instead of multiple with statements
# makes lines too long
SIM117
per-file-ignores = per-file-ignores =

View File

@@ -130,7 +130,7 @@ def is_mov_imm_to_stack(il: MediumLevelILInstruction) -> bool:
if il.src.operation != MediumLevelILOperation.MLIL_CONST: if il.src.operation != MediumLevelILOperation.MLIL_CONST:
return False return False
if not il.dest.source_type == VariableSourceType.StackVariableSourceType: if il.dest.source_type != VariableSourceType.StackVariableSourceType:
return False return False
return True return True

View File

@@ -53,9 +53,7 @@ class BinjaFeatureExtractor(FeatureExtractor):
mlil_lookup[mlil_bb.source_block.start] = mlil_bb mlil_lookup[mlil_bb.source_block.start] = mlil_bb
for bb in f.basic_blocks: for bb in f.basic_blocks:
mlil_bb = None mlil_bb = mlil_lookup.get(bb.start)
if bb.start in mlil_lookup:
mlil_bb = mlil_lookup[bb.start]
yield BBHandle(address=AbsoluteVirtualAddress(bb.start), inner=(bb, mlil_bb)) yield BBHandle(address=AbsoluteVirtualAddress(bb.start), inner=(bb, mlil_bb))

View File

@@ -155,8 +155,7 @@ def extract_insn_number_features(
for llil in func.get_llils_at(ih.address): for llil in func.get_llils_at(ih.address):
visit_llil_exprs(llil, llil_checker) visit_llil_exprs(llil, llil_checker)
for result in results: yield from results
yield result
def extract_insn_bytes_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[Tuple[Feature, Address]]: def extract_insn_bytes_features(fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle) -> Iterator[Tuple[Feature, Address]]:
@@ -318,8 +317,7 @@ def extract_insn_offset_features(
for llil in func.get_llils_at(ih.address): for llil in func.get_llils_at(ih.address):
visit_llil_exprs(llil, llil_checker) visit_llil_exprs(llil, llil_checker)
for result in results: yield from results
yield result
def is_nzxor_stack_cookie(f: Function, bb: BinjaBasicBlock, llil: LowLevelILInstruction) -> bool: def is_nzxor_stack_cookie(f: Function, bb: BinjaBasicBlock, llil: LowLevelILInstruction) -> bool:
@@ -375,8 +373,7 @@ def extract_insn_nzxor_characteristic_features(
for llil in func.get_llils_at(ih.address): for llil in func.get_llils_at(ih.address):
visit_llil_exprs(llil, llil_checker) visit_llil_exprs(llil, llil_checker)
for result in results: yield from results
yield result
def extract_insn_mnemonic_features( def extract_insn_mnemonic_features(
@@ -438,8 +435,7 @@ def extract_insn_peb_access_characteristic_features(
for llil in func.get_llils_at(ih.address): for llil in func.get_llils_at(ih.address):
visit_llil_exprs(llil, llil_checker) visit_llil_exprs(llil, llil_checker)
for result in results: yield from results
yield result
def extract_insn_segment_access_features( def extract_insn_segment_access_features(
@@ -466,8 +462,7 @@ def extract_insn_segment_access_features(
for llil in func.get_llils_at(ih.address): for llil in func.get_llils_at(ih.address):
visit_llil_exprs(llil, llil_checker) visit_llil_exprs(llil, llil_checker)
for result in results: yield from results
yield result
def extract_insn_cross_section_cflow( def extract_insn_cross_section_cflow(

View File

@@ -53,19 +53,19 @@ class DnFileFeatureExtractorCache:
self.types[type_.token] = type_ self.types[type_.token] = type_
def get_import(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]: def get_import(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]:
return self.imports.get(token, None) return self.imports.get(token)
def get_native_import(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]: def get_native_import(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]:
return self.native_imports.get(token, None) return self.native_imports.get(token)
def get_method(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]: def get_method(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]:
return self.methods.get(token, None) return self.methods.get(token)
def get_field(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]: def get_field(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]:
return self.fields.get(token, None) return self.fields.get(token)
def get_type(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]: def get_type(self, token: int) -> Optional[Union[DnType, DnUnmanagedMethod]]:
return self.types.get(token, None) return self.types.get(token)
class DnfileFeatureExtractor(FeatureExtractor): class DnfileFeatureExtractor(FeatureExtractor):
@@ -120,7 +120,7 @@ class DnfileFeatureExtractor(FeatureExtractor):
address: DNTokenAddress = DNTokenAddress(insn.operand.value) address: DNTokenAddress = DNTokenAddress(insn.operand.value)
# record call to destination method; note: we only consider MethodDef methods for destinations # record call to destination method; note: we only consider MethodDef methods for destinations
dest: Optional[FunctionHandle] = methods.get(address, None) dest: Optional[FunctionHandle] = methods.get(address)
if dest is not None: if dest is not None:
dest.ctx["calls_to"].add(fh.address) dest.ctx["calls_to"].add(fh.address)

View File

@@ -52,7 +52,7 @@ def resolve_dotnet_token(pe: dnfile.dnPE, token: Token) -> Union[dnfile.base.MDT
return InvalidToken(token.value) return InvalidToken(token.value)
return user_string return user_string
table: Optional[dnfile.base.ClrMetaDataTable] = pe.net.mdtables.tables.get(token.table, None) table: Optional[dnfile.base.ClrMetaDataTable] = pe.net.mdtables.tables.get(token.table)
if table is None: if table is None:
# table index is not valid # table index is not valid
return InvalidToken(token.value) return InvalidToken(token.value)
@@ -204,7 +204,7 @@ def get_dotnet_managed_methods(pe: dnfile.dnPE) -> Iterator[DnType]:
continue continue
token: int = calculate_dotnet_token_value(method.table.number, method.row_index) token: int = calculate_dotnet_token_value(method.table.number, method.row_index)
access: Optional[str] = accessor_map.get(token, None) access: Optional[str] = accessor_map.get(token)
method_name: str = method.row.Name method_name: str = method.row.Name
if method_name.startswith(("get_", "set_")): if method_name.startswith(("get_", "set_")):

View File

@@ -9,7 +9,7 @@
from typing import Optional from typing import Optional
class DnType(object): class DnType:
def __init__(self, token: int, class_: str, namespace: str = "", member: str = "", access: Optional[str] = None): def __init__(self, token: int, class_: str, namespace: str = "", member: str = "", access: Optional[str] = None):
self.token: int = token self.token: int = token
self.access: Optional[str] = access self.access: Optional[str] = access

View File

@@ -706,8 +706,7 @@ class SymTab:
return a tuple: (name, value, size, info, other, shndx) return a tuple: (name, value, size, info, other, shndx)
for each symbol contained in the symbol table for each symbol contained in the symbol table
""" """
for symbol in self.symbols: yield from self.symbols
yield symbol
@classmethod @classmethod
def from_Elf(cls, ElfBinary) -> Optional["SymTab"]: def from_Elf(cls, ElfBinary) -> Optional["SymTab"]:

View File

@@ -122,7 +122,7 @@ def get_file_externs() -> Dict[int, Tuple[str, str, int]]:
externs = {} externs = {}
for seg in get_segments(skip_header_segments=True): for seg in get_segments(skip_header_segments=True):
if not (seg.type == ida_segment.SEG_XTRN): if seg.type != ida_segment.SEG_XTRN:
continue continue
for ea in idautils.Functions(seg.start_ea, seg.end_ea): for ea in idautils.Functions(seg.start_ea, seg.end_ea):
@@ -275,20 +275,18 @@ def is_op_offset(insn: idaapi.insn_t, op: idaapi.op_t) -> bool:
def is_sp_modified(insn: idaapi.insn_t) -> bool: def is_sp_modified(insn: idaapi.insn_t) -> bool:
"""determine if instruction modifies SP, ESP, RSP""" """determine if instruction modifies SP, ESP, RSP"""
for op in get_insn_ops(insn, target_ops=(idaapi.o_reg,)): return any(
if op.reg == idautils.procregs.sp.reg and is_op_write(insn, op): op.reg == idautils.procregs.sp.reg and is_op_write(insn, op)
# register is stack and written for op in get_insn_ops(insn, target_ops=(idaapi.o_reg,))
return True )
return False
def is_bp_modified(insn: idaapi.insn_t) -> bool: def is_bp_modified(insn: idaapi.insn_t) -> bool:
"""check if instruction modifies BP, EBP, RBP""" """check if instruction modifies BP, EBP, RBP"""
for op in get_insn_ops(insn, target_ops=(idaapi.o_reg,)): return any(
if op.reg == idautils.procregs.bp.reg and is_op_write(insn, op): op.reg == idautils.procregs.bp.reg and is_op_write(insn, op)
# register is base and written for op in get_insn_ops(insn, target_ops=(idaapi.o_reg,))
return True )
return False
def is_frame_register(reg: int) -> bool: def is_frame_register(reg: int) -> bool:
@@ -334,10 +332,7 @@ def mask_op_val(op: idaapi.op_t) -> int:
def is_function_recursive(f: idaapi.func_t) -> bool: def is_function_recursive(f: idaapi.func_t) -> bool:
"""check if function is recursive""" """check if function is recursive"""
for ref in idautils.CodeRefsTo(f.start_ea, True): return any(f.contains(ref) for ref in idautils.CodeRefsTo(f.start_ea, True))
if f.contains(ref):
return True
return False
def is_basic_block_tight_loop(bb: idaapi.BasicBlock) -> bool: def is_basic_block_tight_loop(bb: idaapi.BasicBlock) -> bool:
@@ -386,8 +381,7 @@ def find_data_reference_from_insn(insn: idaapi.insn_t, max_depth: int = 10) -> i
def get_function_blocks(f: idaapi.func_t) -> Iterator[idaapi.BasicBlock]: def get_function_blocks(f: idaapi.func_t) -> Iterator[idaapi.BasicBlock]:
"""yield basic blocks contained in specified function""" """yield basic blocks contained in specified function"""
# leverage idaapi.FC_NOEXT flag to ignore useless external blocks referenced by the function # leverage idaapi.FC_NOEXT flag to ignore useless external blocks referenced by the function
for block in idaapi.FlowChart(f, flags=(idaapi.FC_PREDS | idaapi.FC_NOEXT)): yield from idaapi.FlowChart(f, flags=(idaapi.FC_PREDS | idaapi.FC_NOEXT))
yield block
def is_basic_block_return(bb: idaapi.BasicBlock) -> bool: def is_basic_block_return(bb: idaapi.BasicBlock) -> bool:

View File

@@ -216,7 +216,7 @@ def extract_insn_offset_features(
p_info = capa.features.extractors.ida.helpers.get_op_phrase_info(op) p_info = capa.features.extractors.ida.helpers.get_op_phrase_info(op)
op_off = p_info.get("offset", None) op_off = p_info.get("offset")
if op_off is None: if op_off is None:
continue continue
@@ -447,7 +447,7 @@ def extract_insn_cross_section_cflow(
insn: idaapi.insn_t = ih.inner insn: idaapi.insn_t = ih.inner
for ref in idautils.CodeRefsFrom(insn.ea, False): for ref in idautils.CodeRefsFrom(insn.ea, False):
if ref in get_imports(fh.ctx).keys(): if ref in get_imports(fh.ctx):
# ignore API calls # ignore API calls
continue continue
if not idaapi.getseg(ref): if not idaapi.getseg(ref):

View File

@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and limitations under the License. # See the License for the specific language governing permissions and limitations under the License.
import re import re
import contextlib
from collections import namedtuple from collections import namedtuple
ASCII_BYTE = r" !\"#\$%&\'\(\)\*\+,-\./0123456789:;<=>\?@ABCDEFGHIJKLMNOPQRSTUVWXYZ\[\]\^_`abcdefghijklmnopqrstuvwxyz\{\|\}\\\~\t".encode( ASCII_BYTE = r" !\"#\$%&\'\(\)\*\+,-\./0123456789:;<=>\?@ABCDEFGHIJKLMNOPQRSTUVWXYZ\[\]\^_`abcdefghijklmnopqrstuvwxyz\{\|\}\\\~\t".encode(
@@ -81,7 +82,5 @@ def extract_unicode_strings(buf, n=4):
reg = b"((?:[%s]\x00){%d,})" % (ASCII_BYTE, n) reg = b"((?:[%s]\x00){%d,})" % (ASCII_BYTE, n)
r = re.compile(reg) r = re.compile(reg)
for match in r.finditer(buf): for match in r.finditer(buf):
try: with contextlib.suppress(UnicodeDecodeError):
yield String(match.group().decode("utf-16"), match.start()) yield String(match.group().decode("utf-16"), match.start())
except UnicodeDecodeError:
pass

View File

@@ -7,7 +7,7 @@
# See the License for the specific language governing permissions and limitations under the License. # See the License for the specific language governing permissions and limitations under the License.
import collections import collections
from typing import Set, List, Deque, Tuple, Union, Optional from typing import Set, List, Deque, Tuple, Optional
import envi import envi
import vivisect.const import vivisect.const
@@ -71,7 +71,7 @@ class NotFoundError(Exception):
pass pass
def find_definition(vw: VivWorkspace, va: int, reg: int) -> Tuple[int, Union[int, None]]: def find_definition(vw: VivWorkspace, va: int, reg: int) -> Tuple[int, Optional[int]]:
""" """
scan backwards from the given address looking for assignments to the given register. scan backwards from the given address looking for assignments to the given register.
if a constant, return that value. if a constant, return that value.

View File

@@ -410,9 +410,7 @@ def extract_insn_obfs_call_plus_5_characteristic_features(f, bb, ih: InsnHandle)
if insn.va + 5 == insn.opers[0].getOperValue(insn): if insn.va + 5 == insn.opers[0].getOperValue(insn):
yield Characteristic("call $+5"), ih.address yield Characteristic("call $+5"), ih.address
if isinstance(insn.opers[0], envi.archs.i386.disasm.i386ImmMemOper) or isinstance( if isinstance(insn.opers[0], (envi.archs.i386.disasm.i386ImmMemOper, envi.archs.amd64.disasm.Amd64RipRelOper)):
insn.opers[0], envi.archs.amd64.disasm.Amd64RipRelOper
):
if insn.va + 5 == insn.opers[0].getOperAddr(insn): if insn.va + 5 == insn.opers[0].getOperAddr(insn):
yield Characteristic("call $+5"), ih.address yield Characteristic("call $+5"), ih.address

View File

@@ -197,11 +197,11 @@ class CapaRuleGenFeatureCache:
return features, matches return features, matches
def _get_cached_func_node(self, fh: FunctionHandle) -> Optional[CapaRuleGenFeatureCacheNode]: def _get_cached_func_node(self, fh: FunctionHandle) -> Optional[CapaRuleGenFeatureCacheNode]:
f_node: Optional[CapaRuleGenFeatureCacheNode] = self.func_nodes.get(fh.address, None) f_node: Optional[CapaRuleGenFeatureCacheNode] = self.func_nodes.get(fh.address)
if f_node is None: if f_node is None:
# function is not in our cache, do extraction now # function is not in our cache, do extraction now
self._find_function_and_below_features(fh) self._find_function_and_below_features(fh)
f_node = self.func_nodes.get(fh.address, None) f_node = self.func_nodes.get(fh.address)
return f_node return f_node
def get_all_function_features(self, fh: FunctionHandle) -> FeatureSet: def get_all_function_features(self, fh: FunctionHandle) -> FeatureSet:

View File

@@ -1204,11 +1204,11 @@ class CapaExplorerForm(idaapi.PluginForm):
self.set_rulegen_status(f"Failed to create function rule matches from rule set ({e})") self.set_rulegen_status(f"Failed to create function rule matches from rule set ({e})")
return return
if rule.scope == capa.rules.Scope.FUNCTION and rule.name in func_matches.keys(): if rule.scope == capa.rules.Scope.FUNCTION and rule.name in func_matches:
is_match = True is_match = True
elif rule.scope == capa.rules.Scope.BASIC_BLOCK and rule.name in bb_matches.keys(): elif rule.scope == capa.rules.Scope.BASIC_BLOCK and rule.name in bb_matches:
is_match = True is_match = True
elif rule.scope == capa.rules.Scope.INSTRUCTION and rule.name in insn_matches.keys(): elif rule.scope == capa.rules.Scope.INSTRUCTION and rule.name in insn_matches:
is_match = True is_match = True
elif rule.scope == capa.rules.Scope.FILE: elif rule.scope == capa.rules.Scope.FILE:
try: try:
@@ -1216,7 +1216,7 @@ class CapaExplorerForm(idaapi.PluginForm):
except Exception as e: except Exception as e:
self.set_rulegen_status(f"Failed to create file rule matches from rule set ({e})") self.set_rulegen_status(f"Failed to create file rule matches from rule set ({e})")
return return
if rule.name in file_matches.keys(): if rule.name in file_matches:
is_match = True is_match = True
else: else:
is_match = False is_match = False

View File

@@ -30,7 +30,7 @@ class CapaExplorerIdaHooks(idaapi.UI_Hooks):
@retval must be 0 @retval must be 0
""" """
self.process_action_handle = self.process_action_hooks.get(name, None) self.process_action_handle = self.process_action_hooks.get(name)
if self.process_action_handle: if self.process_action_handle:
self.process_action_handle(self.process_action_meta) self.process_action_handle(self.process_action_meta)

View File

@@ -130,8 +130,7 @@ class CapaExplorerDataItem:
def children(self) -> Iterator["CapaExplorerDataItem"]: def children(self) -> Iterator["CapaExplorerDataItem"]:
"""yield children""" """yield children"""
for child in self._children: yield from self._children
yield child
def removeChildren(self): def removeChildren(self):
"""remove children""" """remove children"""

View File

@@ -628,7 +628,7 @@ class CapaExplorerDataModel(QtCore.QAbstractItemModel):
matched_rule_source = "" matched_rule_source = ""
# check if match is a matched rule # check if match is a matched rule
matched_rule = doc.rules.get(feature.match, None) matched_rule = doc.rules.get(feature.match)
if matched_rule is not None: if matched_rule is not None:
matched_rule_source = matched_rule.source matched_rule_source = matched_rule.source

View File

@@ -1224,8 +1224,7 @@ class CapaExplorerQtreeView(QtWidgets.QTreeView):
yield self.new_action(*action) yield self.new_action(*action)
# add default actions # add default actions
for action in self.load_default_context_menu_actions(data): yield from self.load_default_context_menu_actions(data)
yield action
def load_default_context_menu(self, pos, item, model_index): def load_default_context_menu(self, pos, item, model_index):
"""create default custom context menu """create default custom context menu

View File

@@ -327,10 +327,9 @@ def find_capabilities(ruleset: RuleSet, extractor: FeatureExtractor, disable_pro
def has_rule_with_namespace(rules: RuleSet, capabilities: MatchResults, namespace: str) -> bool: def has_rule_with_namespace(rules: RuleSet, capabilities: MatchResults, namespace: str) -> bool:
for rule_name in capabilities.keys(): return any(
if rules.rules[rule_name].meta.get("namespace", "").startswith(namespace): rules.rules[rule_name].meta.get("namespace", "").startswith(namespace) for rule_name in capabilities.keys()
return True )
return False
def is_internal_rule(rule: Rule) -> bool: def is_internal_rule(rule: Rule) -> bool:

View File

@@ -6,7 +6,7 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License. # See the License for the specific language governing permissions and limitations under the License.
from typing import Dict, Iterable from typing import Dict, Iterable, Optional
import tabulate import tabulate
@@ -129,6 +129,7 @@ def render_feature(ostream, match: rd.Match, feature: frzf.Feature, indent=0):
ostream.write(" " * indent) ostream.write(" " * indent)
key = feature.type key = feature.type
value: Optional[str]
if isinstance(feature, frzf.BasicBlockFeature): if isinstance(feature, frzf.BasicBlockFeature):
# i don't think it makes sense to have standalone basic block features. # i don't think it makes sense to have standalone basic block features.
# we don't parse them from rules, only things like: `count(basic block) > 1` # we don't parse them from rules, only things like: `count(basic block) > 1`
@@ -140,7 +141,7 @@ def render_feature(ostream, match: rd.Match, feature: frzf.Feature, indent=0):
value = feature.class_ value = feature.class_
else: else:
# convert attributes to dictionary using aliased names, if applicable # convert attributes to dictionary using aliased names, if applicable
value = feature.dict(by_alias=True).get(key, None) value = feature.dict(by_alias=True).get(key)
if value is None: if value is None:
raise ValueError(f"{key} contains None") raise ValueError(f"{key} contains None")

View File

@@ -709,8 +709,7 @@ class Rule:
# note: we cannot recurse into the subscope sub-tree, # note: we cannot recurse into the subscope sub-tree,
# because its been replaced by a `match` statement. # because its been replaced by a `match` statement.
for child in statement.get_children(): for child in statement.get_children():
for new_rule in self._extract_subscope_rules_rec(child): yield from self._extract_subscope_rules_rec(child)
yield new_rule
def is_subscope_rule(self): def is_subscope_rule(self):
return bool(self.meta.get("capa/subscope-rule", False)) return bool(self.meta.get("capa/subscope-rule", False))
@@ -736,8 +735,7 @@ class Rule:
# replace old node with reference to new rule # replace old node with reference to new rule
# yield new rule # yield new rule
for new_rule in self._extract_subscope_rules_rec(self.statement): yield from self._extract_subscope_rules_rec(self.statement)
yield new_rule
def evaluate(self, features: FeatureSet, short_circuit=True): def evaluate(self, features: FeatureSet, short_circuit=True):
capa.perf.counters["evaluate.feature"] += 1 capa.perf.counters["evaluate.feature"] += 1

View File

@@ -74,6 +74,7 @@ dev = [
"flake8-no-implicit-concat==0.3.4", "flake8-no-implicit-concat==0.3.4",
"flake8-print==5.0.0", "flake8-print==5.0.0",
"flake8-todos==0.3.0", "flake8-todos==0.3.0",
"flake8-simplify==0.20.0",
"ruff==0.0.277", "ruff==0.0.277",
"black==23.3.0", "black==23.3.0",
"isort==5.11.4", "isort==5.11.4",

View File

@@ -397,7 +397,7 @@ def convert_rule(rule, rulename, cround, depth):
# this is "x or more". could be coded for strings TODO # this is "x or more". could be coded for strings TODO
return "BREAK", "Some aka x or more (TODO)", rule_comment, incomplete return "BREAK", "Some aka x or more (TODO)", rule_comment, incomplete
if s_type == "And" or s_type == "Or" or s_type == "Not" and not kid.name == "Some": if s_type == "And" or s_type == "Or" or s_type == "Not" and kid.name != "Some":
logger.info("doing bool with recursion: %r", kid) logger.info("doing bool with recursion: %r", kid)
logger.info("kid coming: %r", kid.name) logger.info("kid coming: %r", kid.name)
# logger.info("grandchildren: " + repr(kid.children)) # logger.info("grandchildren: " + repr(kid.children))

View File

@@ -72,7 +72,7 @@ def load_analysis(bv):
md5 = binaryninja.Transform["MD5"] md5 = binaryninja.Transform["MD5"]
rawhex = binaryninja.Transform["RawHex"] rawhex = binaryninja.Transform["RawHex"]
b = rawhex.encode(md5.encode(bv.parent_view.read(bv.parent_view.start, bv.parent_view.end))).decode("utf-8") b = rawhex.encode(md5.encode(bv.parent_view.read(bv.parent_view.start, bv.parent_view.end))).decode("utf-8")
if not a == b: if a != b:
binaryninja.log_error("sample mismatch") binaryninja.log_error("sample mismatch")
return -2 return -2

View File

@@ -279,7 +279,7 @@ class InvalidAttckOrMbcTechnique(Lint):
def check_rule(self, ctx: Context, rule: Rule): def check_rule(self, ctx: Context, rule: Rule):
for framework in self.enabled_frameworks: for framework in self.enabled_frameworks:
if framework in rule.meta.keys(): if framework in rule.meta:
for r in rule.meta[framework]: for r in rule.meta[framework]:
m = self.reg.match(r) m = self.reg.match(r)
if m is None: if m is None:
@@ -543,8 +543,7 @@ class FeatureNtdllNtoskrnlApi(Lint):
assert isinstance(feature.value, str) assert isinstance(feature.value, str)
modname, _, impname = feature.value.rpartition(".") modname, _, impname = feature.value.rpartition(".")
if modname == "ntdll": if modname == "ntdll" and impname in (
if impname in (
"LdrGetProcedureAddress", "LdrGetProcedureAddress",
"LdrLoadDll", "LdrLoadDll",
"NtCreateThread", "NtCreateThread",
@@ -574,8 +573,7 @@ class FeatureNtdllNtoskrnlApi(Lint):
# ntoskrnl.exe does not export these routines # ntoskrnl.exe does not export these routines
continue continue
if modname == "ntoskrnl": if modname == "ntoskrnl" and impname in (
if impname in (
"PsGetVersion", "PsGetVersion",
"PsLookupProcessByProcessId", "PsLookupProcessByProcessId",
"KeStackAttachProcess", "KeStackAttachProcess",

View File

@@ -50,8 +50,7 @@ def main():
for i in range(count): for i in range(count):
print(f"iteration {i+1}/{count}...") print(f"iteration {i+1}/{count}...")
with contextlib.redirect_stdout(io.StringIO()): with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
with contextlib.redirect_stderr(io.StringIO()):
t0 = time.time() t0 = time.time()
capa.main.main() capa.main.main()
t1 = time.time() t1 = time.time()

View File

@@ -7,6 +7,7 @@
# See the License for the specific language governing permissions and limitations under the License. # See the License for the specific language governing permissions and limitations under the License.
import textwrap import textwrap
import contextlib
import capa.rules import capa.rules
import capa.rules.cache import capa.rules.cache
@@ -74,10 +75,8 @@ def test_ruleset_cache_save_load():
cache_dir = capa.rules.cache.get_default_cache_directory() cache_dir = capa.rules.cache.get_default_cache_directory()
path = capa.rules.cache.get_cache_path(cache_dir, id) path = capa.rules.cache.get_cache_path(cache_dir, id)
try: with contextlib.suppress(OSError):
path.unlink() path.unlink()
except OSError:
pass
capa.rules.cache.cache_ruleset(cache_dir, rs) capa.rules.cache.cache_ruleset(cache_dir, rs)
assert path.exists() assert path.exists()
@@ -91,10 +90,8 @@ def test_ruleset_cache_invalid():
id = capa.rules.cache.compute_cache_identifier(content) id = capa.rules.cache.compute_cache_identifier(content)
cache_dir = capa.rules.cache.get_default_cache_directory() cache_dir = capa.rules.cache.get_default_cache_directory()
path = capa.rules.cache.get_cache_path(cache_dir, id) path = capa.rules.cache.get_cache_path(cache_dir, id)
try: with contextlib.suppress(OSError):
path.unlink() path.unlink()
except OSError:
pass
capa.rules.cache.cache_ruleset(cache_dir, rs) capa.rules.cache.cache_ruleset(cache_dir, rs)
assert path.exists() assert path.exists()