add wip proto translator using introspection

This commit is contained in:
Willi Ballenthin
2023-02-17 11:11:14 +01:00
parent 099cd868ae
commit c0ff0c2124
2 changed files with 175 additions and 15 deletions

View File

@@ -5,8 +5,8 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import io
import sys
import logging
from typing import Dict, Union
from dataclasses import dataclass
@@ -20,17 +20,26 @@ import capa.features.freeze.features
from capa.render.utils import StringIO
def emit_proto_enum(out: StringIO, enum):
# like: AddressType
title = enum["title"]
logger = logging.getLogger(__name__)
def is_enum(prop):
return "type" in prop and prop["type"] == "string" and "enum" in prop
def get_enum_name(prop):
return prop["title"]
def get_enum_value_name(enum, value):
# like: ADDRESSTYPE
prefix = title.upper()
prefix = get_enum_name(enum).upper()
def render_value(value):
# like: ADDRESSTYPE_ABSOLUTE
return "%s_%s" % (prefix, value.upper().replace(" ", "_"))
# like: ADDRESSTYPE_ABSOLUTE
return "%s_%s" % (prefix, value.upper().replace(" ", "_"))
def emit_proto_enum(out: StringIO, enum):
# like:
#
# enum AddressType {
@@ -39,10 +48,10 @@ def emit_proto_enum(out: StringIO, enum):
# ADDRESSTYPE_RELATIVE = 2;
# ...
# }
out.writeln(f"enum {title} {{")
out.writeln(f' {render_value("unspecified")} = 0;')
out.writeln(f"enum {get_enum_name(enum)} {{")
out.writeln(f' {get_enum_value_name(enum, "unspecified")} = 0;')
for i, value in enumerate(enum["enum"]):
out.writeln(f" {render_value(value)} = {i + 1};")
out.writeln(f" {get_enum_value_name(enum, value)} = {i + 1};")
out.writeln(f"}}")
out.writeln("")
@@ -63,7 +72,7 @@ def get_ref_type_name(prop):
def is_primitive_type(prop):
# things like: string, integer, bool, etc.
return "type" in prop and not prop["type"] == "object"
return "type" in prop and not prop["type"] == "object" and not "enum" in prop
def is_custom_type(prop):
@@ -207,6 +216,8 @@ def get_type_name(prop):
return get_custom_type_name(prop)
elif is_ref(prop):
return get_ref_type_name(prop)
elif is_enum(prop):
return get_enum_name(prop)
else:
raise NotImplementedError(prop)
@@ -442,3 +453,144 @@ def generate_proto() -> str:
and then commit the proto to the repo.
"""
return generate_proto_from_pydantic(pydantic.schema_of(capa.render.result_document.ResultDocument))
def int_to_pb2(v):
assert isinstance(v, int)
if v < -2_147_483_648:
raise ValueError("underflow")
if v > 0xFFFFFFFFFFFFFFFF:
raise ValueError("overflow")
if v < 0:
return capa.render.proto.capa_pb2.Integer(i=v)
else:
return capa.render.proto.capa_pb2.Integer(u=v)
def translate_to_pb2(schema, typ, src, dst):
logger.debug("translate: %s", get_type_name(typ))
if is_custom_type(typ):
for pname, ptyp in typ["properties"].items():
if is_union(ptyp):
logger.debug("translate: %s.%s (union)", get_type_name(typ), pname)
elif is_map(ptyp):
logger.debug("translate: %s.%s (map)", get_type_name(typ), pname)
else:
logger.debug("translate: %s.%s (%s)", get_type_name(typ), pname, get_type_name(ptyp))
psrc = getattr(src, pname)
if is_ref(ptyp):
logger.debug("resolving ref: %s", get_type_name(ptyp))
ptyp = schema["definitions"][get_ref_type_name(ptyp)]
if is_primitive_type(ptyp):
if ptyp["type"] == "string":
if "format" in ptyp and ptyp["format"] == "date-time":
pdst = psrc.isoformat("T") + "Z"
else:
pdst = psrc
setattr(dst, pname, pdst)
elif ptyp["type"] == "integer":
getattr(dst, pname).CopyFrom(int_to_pb2(psrc))
# TODO: move array out of primitives
elif is_array(ptyp):
vtyp = ptyp["items"]
if is_ref(vtyp):
logger.debug("resolving ref: %s", get_type_name(vtyp))
vtyp = schema["definitions"][get_ref_type_name(vtyp)]
if get_type_name(vtyp) == "string":
pdst = getattr(dst, pname)
for v in psrc:
pdst.append(v)
elif is_custom_type(vtyp):
pdst = getattr(dst, pname)
Dst = getattr(capa.render.proto.capa_pb2, get_type_name(vtyp))
for psrcv in psrc:
pdst = Dst()
translate_to_pb2(schema, vtyp, psrcv, pdst)
getattr(dst, pname).append(pdst)
else:
raise NotImplementedError(get_type_name(vtyp))
# TODO: move tuple out of primitives
elif is_tuple(ptyp):
raise NotImplementedError("tuple")
else:
raise NotImplementedError(ptyp["type"])
elif is_custom_type(ptyp):
ptyp = schema["definitions"][get_type_name(ptyp)]
Dst = getattr(capa.render.proto.capa_pb2, get_type_name(ptyp))
pdst = Dst()
translate_to_pb2(schema, ptyp, psrc, pdst)
# you can't just assign to a non-initialized composite field.
#
# https://stackoverflow.com/a/22771612/87207
getattr(dst, pname).CopyFrom(pdst)
elif is_enum(ptyp):
# like: AddressType
Enum = getattr(capa.render.proto.capa_pb2, get_type_name(ptyp))
# like: AddressType.ADDRESSTYPE_ABSOLUTE
v = getattr(Enum, get_enum_value_name(ptyp, psrc.value))
setattr(dst, pname, v)
elif is_tuple(ptyp):
raise NotImplementedError("tuple")
elif is_union(ptyp):
# in this scenario, we have a field that can be one of several types.
# in the proto message, we set *one* of many disjoint fields.
# they are named v0, v1, v2, etc. and not named after the type.
# so, we need to match up the types and resolve the destination field name.
# it is guaranteed that of the candidate fields, they each have a unique type.
# 1. resolve the name of the source type
ptypname = None
for candidate_type in ptyp["anyOf"]:
logger.debug("candidate: %s", get_type_name(candidate_type))
if get_type_name(candidate_type) == "Integer" and isinstance(psrc, int):
# special handling of numbers to account for range
ptypname = "Integer"
if not ptypname:
raise NotImplementedError(ptyp)
pdstname = None
for candidate_descriptor in dst.DESCRIPTOR.oneofs_by_name[pname].fields:
if candidate_descriptor.type == 11:
if candidate_descriptor.message_type.full_name == ptypname:
pdstname = candidate_descriptor.name
break
else:
raise NotImplementedError(candidate_descriptor.type)
if not pdstname:
raise NotImplementedError(ptypname)
if ptypname == "Integer":
getattr(dst, pdstname).CopyFrom(int_to_pb2(psrc))
else:
raise NotImplementedError(type(psrc))
else:
raise NotImplementedError(get_type_name(ptyp))
else:
raise NotImplementedError(get_type_name(typ))

View File

@@ -57,9 +57,17 @@ def test_generate_proto(tmp_path: pathlib.Path):
print("=====================================")
def test_translate_to_proto(pma0101_rd: ResultDocument):
def test_translate_to_pb2(pma0101_rd: ResultDocument):
schema = pydantic.schema_of(capa.render.result_document.ResultDocument)
src = pma0101_rd
dst = capa.render.proto.capa_pb2.ResultDocument()
typ = schema["definitions"]["ResultDocument"]
assert True
capa.render.proto.translate_to_pb2(schema, typ, src, dst)
print()
print("=====================================")
print(dst)
print("=====================================")
assert False