mirror of
https://github.com/immich-app/immich.git
synced 2026-02-04 19:12:11 -08:00
35 lines
1.3 KiB
Python
35 lines
1.3 KiB
Python
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import onnxruntime as ort
|
|
from numpy.typing import NDArray
|
|
from onnx.shape_inference import infer_shapes
|
|
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
|
|
|
|
|
def ort_has_batch_dim(session: ort.InferenceSession) -> bool:
|
|
return session.get_inputs()[0].shape[0] == "batch"
|
|
|
|
|
|
def ort_squeeze_outputs(session: ort.InferenceSession) -> None:
|
|
original_run = session.run
|
|
|
|
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
|
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
|
|
out = [o.squeeze(axis=0) for o in out]
|
|
return out
|
|
|
|
session.run = run
|
|
|
|
|
|
def ort_add_batch_dim(input_path: Path, output_path: Path) -> None:
|
|
proto = onnx.load(input_path)
|
|
static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
|
|
static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
|
|
input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
|
|
output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
|
|
updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
|
|
inferred = infer_shapes(updated_proto)
|
|
onnx.save(inferred, output_path)
|