capafmt: use yaml parser that supports comments to reformat

This commit is contained in:
William Ballenthin
2020-06-21 11:53:15 -06:00
parent 56536792f8
commit 3bc6c5805f
2 changed files with 23 additions and 77 deletions

View File

@@ -1,9 +1,12 @@
import yaml
import uuid import uuid
import codecs import codecs
import logging import logging
import binascii import binascii
import six
import yaml
import ruamel.yaml
import capa.engine import capa.engine
from capa.engine import * from capa.engine import *
import capa.features import capa.features
@@ -509,13 +512,27 @@ class Rule(object):
raise InvalidRuleWithPath(path, str(e)) raise InvalidRuleWithPath(path, str(e))
def to_yaml(self): def to_yaml(self):
import six # reformat the yaml document with a common style.
from ruamel.yaml import YAML # this includes:
# - ordering the meta elements
# - indenting the nested items with two spaces
#
# we use the ruamel.yaml parser for this, because it supports roundtripping of documents with comments.
# order the meta elements in the following preferred order.
# any custom keys will come after this.
COMMON_KEYS = ("name", "namespace", "rule-category", "author", "att&ck", "mbc", "examples", "scope") COMMON_KEYS = ("name", "namespace", "rule-category", "author", "att&ck", "mbc", "examples", "scope")
yaml = YAML(typ='rt') yaml = ruamel.yaml.YAML(typ='rt')
# use block mode, not inline json-like mode
yaml.default_flow_style = False yaml.default_flow_style = False
# indent lists by two spaces below their parent
#
# features:
# - or:
# - mnemonic: aesdec
# - mnemonic: vaesdec
yaml.indent(sequence=2, offset=2)
definition = yaml.load(self.definition) definition = yaml.load(self.definition)
# definition retains a reference to `meta`, # definition retains a reference to `meta`,
@@ -541,79 +558,7 @@ class Rule(object):
ostream = six.BytesIO() ostream = six.BytesIO()
yaml.dump(definition, ostream) yaml.dump(definition, ostream)
print(ostream.getvalue().decode('utf-8')) return ostream.getvalue().decode('utf-8').rstrip("\n")
return ''
definition = yaml.safe_load(self.definition)
formatted = DefaultOrderedDict(default_factory=lambda: DefaultOrderedDict(default_factory=DefaultOrderedDict))
meta = definition["rule"]["meta"]
for key in COMMON_KEYS:
if key in meta:
formatted["rule"]["meta"][key] = meta[key]
for key in sorted(meta.keys()):
if key in COMMON_KEYS:
continue
formatted["rule"]["meta"][key] = meta[key]
formatted["rule"]["features"] = definition["rule"]["features"]
return yaml.dump(formatted, Dumper=CapaDumper, default_flow_style=False)
class DefaultOrderedDict(collections.OrderedDict):
    '''
    An `OrderedDict` that creates missing entries on first access,
    in the manner of `collections.defaultdict`.

    Source: http://stackoverflow.com/a/6190500/562769
    '''

    def __init__(self, default_factory=None, *a, **kw):
        '''
        args:
          default_factory: zero-argument callable used to build the value for a missing key,
            or None to raise KeyError on missing keys.
          *a, **kw: forwarded unchanged to `collections.OrderedDict`.

        raises:
          TypeError: when `default_factory` is neither None nor callable.
        '''
        # use the `callable` builtin instead of `collections.Callable`:
        # that alias was removed from `collections` in Python 3.10,
        # while the builtin is available on both Python 2 and Python 3.
        if default_factory is not None and not callable(default_factory):
            raise TypeError('first argument must be callable')
        super(DefaultOrderedDict, self).__init__(*a, **kw)
        self.default_factory = default_factory

    def __getitem__(self, key):
        '''
        fetch the value stored under `key`, falling back to `__missing__` when it is absent.
        '''
        try:
            return super(DefaultOrderedDict, self).__getitem__(key)
        except KeyError:
            return self.__missing__(key)

    def __missing__(self, key):
        '''
        build, store, and return the default value for `key`.

        raises:
          KeyError: when no default factory was configured.
        '''
        if self.default_factory is None:
            raise KeyError(key)
        value = self.default_factory()
        # store the new value so that later lookups return the same object.
        self[key] = value
        return value
class CapaDumper(yaml.Dumper):
    '''
    Tweak the yaml serializer to emit sequences/lists with additional indentation.

    ref: https://stackoverflow.com/a/39681672/87207

    before:

        rule:
          features:
          - or:
            - count(mnemonic(rdtsc)): 2 or more
            - mnemonic: icebp

    after:

        rule:
          features:
            - or:
              - count(mnemonic(rdtsc)): 2 or more
              - mnemonic: icebp
    '''

    def __init__(self, *args, **kwargs):
        '''
        construct the dumper and register a representer for `DefaultOrderedDict`.

        all arguments are forwarded unchanged to `yaml.Dumper`.
        '''
        super(CapaDumper, self).__init__(*args, **kwargs)
        # hand the (key, value) pairs to the dumper, rather than the mapping itself.
        # `items()` exists on both Python 2 and Python 3, whereas `iteritems()` is Python 2 only.
        # NOTE(review): passing pairs presumably keeps insertion order instead of sorted keys; confirm against PyYAML.
        self.add_representer(DefaultOrderedDict, lambda dumper, data: dumper.represent_dict(data.items()))

    def increase_indent(self, flow=False, indentless=False):
        '''
        increase the indentation level, always passing `indentless=False` to the base class.

        the caller-provided `indentless` value is deliberately ignored;
        this is what yields the extra indentation shown in the class docstring.
        '''
        return super(CapaDumper, self).increase_indent(flow, False)
def get_rules_with_scope(rules, scope): def get_rules_with_scope(rules, scope):

View File

@@ -9,6 +9,7 @@ requirements = [
"tqdm", "tqdm",
"pyyaml", "pyyaml",
"tabulate", "tabulate",
"ruamel.yaml"
] ]
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):