Files
immich/machine-learning/app/models/session.py
mertalev 259386cf13 refactor
2024-06-08 21:24:23 -04:00

35 lines
1.3 KiB
Python

from pathlib import Path
import numpy as np
import onnx
import onnxruntime as ort
from numpy.typing import NDArray
from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims
def ort_has_batch_dim(session: ort.InferenceSession) -> bool:
    """Tell whether the session's first input has a leading dimension named "batch".

    Only the first entry of ``session.get_inputs()`` is inspected, and only the
    exact symbolic name "batch" counts. That is the name ``ort_add_batch_dim``
    writes. Fixed sizes (ints), unnamed dims, or differently named symbolic dims
    all yield ``False``.
    """
    first_input = session.get_inputs()[0]
    leading_dim = first_input.shape[0]
    return leading_dim == "batch"
def ort_squeeze_outputs(session: ort.InferenceSession) -> None:
    """Patch ``session.run`` in place so every returned output loses its leading axis.

    The original bound ``run`` is captured, and a wrapper is assigned as an
    instance attribute. The wrapper calls the original and applies
    ``squeeze(axis=0)`` to each output. NumPy raises ``ValueError`` if that axis
    does not have length 1. Presumably the session is fed one item at a time
    through a model with a leading batch axis; confirm with callers.

    The optional ``run_options`` argument of ``InferenceSession.run`` is
    forwarded, so callers passing it keep working after the patch.
    """
    original_run = session.run

    def run(
        output_names: list[str],
        input_feed: dict[str, NDArray[np.float32]],
        run_options: "ort.RunOptions | None" = None,
    ) -> list[NDArray[np.float32]]:
        out: list[NDArray[np.float32]] = original_run(output_names, input_feed, run_options)
        # Drop the length-1 leading axis from every output.
        return [o.squeeze(axis=0) for o in out]

    session.run = run
def ort_add_batch_dim(input_path: Path, output_path: Path) -> None:
    """Write a copy of an ONNX model whose first input and output lead with a "batch" dim.

    The model at ``input_path`` is loaded. The leading dimension of
    ``graph.input[0]`` and ``graph.output[0]`` is replaced with the symbolic
    name "batch", and every later dimension of those two tensors is carried over
    unchanged. Shapes are then re-inferred with ``infer_shapes``, and the result
    is saved to ``output_path``.

    Args:
        input_path: Path of the ONNX model to read.
        output_path: Path the rewritten model is saved to.
    """
    proto = onnx.load(input_path)

    def trailing_dims(value_info: onnx.ValueInfoProto) -> list[object]:
        # Reading ``dim_value`` on a symbolic (or unset) dim yields the protobuf
        # default 0, and ``update_inputs_outputs_dims`` would then pin that axis
        # to a static size of 0. Keep symbolic dims by name and fixed dims by
        # size. Pass -1 for unknown dims so onnx assigns a fresh symbolic name.
        dims: list[object] = []
        for dim in value_info.type.tensor_type.shape.dim[1:]:
            if dim.HasField("dim_param"):
                dims.append(dim.dim_param)
            elif dim.HasField("dim_value"):
                dims.append(dim.dim_value)
            else:
                dims.append(-1)
        return dims

    first_input = proto.graph.input[0]
    first_output = proto.graph.output[0]
    # NOTE(review): update_inputs_outputs_dims appears to look up every graph
    # input/output name in these dicts. Models with more than one input or
    # output are then likely unsupported here; confirm before relying on it.
    input_dims = {first_input.name: ["batch", *trailing_dims(first_input)]}
    output_dims = {first_output.name: ["batch", *trailing_dims(first_output)]}
    updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
    inferred = infer_shapes(updated_proto)
    onnx.save(inferred, output_path)