import os
import platform
import subprocess
from shutil import rmtree
from typing import Callable, ClassVar

import numpy as np
import onnx
import onnx2tf
import onnxsim
from huggingface_hub import login, snapshot_download, upload_file
from onnx.shape_inference import infer_shapes_path
from onnx_graphsurgeon import Constant, Node, Variable, export_onnx, import_onnx
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed


# hack: changed Mul op in onnx2tf to skip broadcast if graph_node.o().op == 'Sigmoid'


# i can explain
# armnn only supports up to 4d transposes, but the model has a 5d transpose due to a redundant unsqueeze
# this function folds the unsqueeze+transpose+squeeze into a single 4d transpose
# it also switches from gather ops to slices since armnn has different dimension semantics for gathers
# also fixes batch normalization being in training mode
def make_onnx_armnn_compatible(model_path: str):
    proto = onnx.load(model_path)
    graph = import_onnx(proto)

    gather_idx = 1
    squeeze_idx = 1
    for node in graph.nodes:
        for link1 in node.outputs:
            if "Unsqueeze" in link1.name:
                for node1 in link1.outputs:
                    for link2 in node1.outputs:
                        if "Transpose" in link2.name:
                            for node2 in link2.outputs:
                                if node2.attrs.get("perm") == [3, 1, 2, 0, 4]:
                                    # fold unsqueeze + 5d transpose + squeeze into a single 4d transpose
                                    node2.attrs["perm"] = [2, 0, 1, 3]
                                    link2.shape = link1.shape
                                    for link3 in node2.outputs:
                                        if "Squeeze" in link3.name:
                                            link3.shape = [link3.shape[x] for x in [0, 1, 2, 4]]
                                            for node3 in link3.outputs:
                                                for link4 in node3.outputs:
                                                    link4.shape = link3.shape
                                                    try:
                                                        idx = link2.inputs.index(node1)
                                                        link2.inputs[idx] = node
                                                    except ValueError:
                                                        pass

                                                    node.outputs = [link2]
                                                    if "Gather" in link4.name:
                                                        # replace the downstream Gather with an equivalent Slice
                                                        for node4 in link4.outputs:
                                                            axis = node1.attrs.get("axis", 0)
                                                            index = node4.inputs[1].values
                                                            slice_link = Variable(
                                                                f"onnx::Slice_123{gather_idx}",
                                                                dtype=link4.dtype,
                                                                shape=[1] + link3.shape[1:],
                                                            )
                                                            slice_node = Node(
                                                                op="Slice",
                                                                inputs=[
                                                                    link3,
                                                                    Constant(
                                                                        f"SliceStart_123{gather_idx}",
                                                                        np.array([index]),
                                                                    ),
                                                                    Constant(
                                                                        f"SliceEnd_123{gather_idx}",
                                                                        np.array([index + 1]),
                                                                    ),
                                                                    Constant(
                                                                        f"SliceAxis_123{gather_idx}",
                                                                        np.array([axis]),
                                                                    ),
                                                                ],
                                                                outputs=[slice_link],
                                                                name=f"Slice_123{gather_idx}",
                                                            )
                                                            graph.nodes.append(slice_node)
                                                            gather_idx += 1

                                                            for link5 in node4.outputs:
                                                                for node5 in link5.outputs:
                                                                    try:
                                                                        idx = node5.inputs.index(link5)
                                                                        node5.inputs[idx] = slice_link
                                                                    except ValueError:
                                                                        pass
            elif node.op == "LayerNormalization":
                for node1 in link1.outputs:
                    if node1.op == "Gather":
                        # replace the Gather after LayerNormalization with an equivalent Slice + Squeeze
                        for link2 in node1.outputs:
                            for node2 in link2.outputs:
                                axis = node1.attrs.get("axis", 0)
                                index = node1.inputs[1].values
                                slice_link = Variable(
                                    f"onnx::Slice_123{gather_idx}",
                                    dtype=link2.dtype,
                                    shape=[1] + link2.shape,
                                )
                                slice_node = Node(
                                    op="Slice",
                                    inputs=[
                                        node1.inputs[0],
                                        Constant(
                                            f"SliceStart_123{gather_idx}",
                                            np.array([index]),
                                        ),
                                        Constant(
                                            f"SliceEnd_123{gather_idx}",
                                            np.array([index + 1]),
                                        ),
                                        Constant(
                                            f"SliceAxis_123{gather_idx}",
                                            np.array([axis]),
                                        ),
                                    ],
                                    outputs=[slice_link],
                                    name=f"Slice_123{gather_idx}",
                                )
                                graph.nodes.append(slice_node)
                                gather_idx += 1

                                squeeze_link = Variable(
                                    f"onnx::Squeeze_123{squeeze_idx}",
                                    dtype=link2.dtype,
                                    shape=link2.shape,
                                )
                                squeeze_node = Node(
                                    op="Squeeze",
                                    inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}", np.array([0]))],
                                    outputs=[squeeze_link],
                                    name=f"Squeeze_123{squeeze_idx}",
                                )
                                graph.nodes.append(squeeze_node)
                                squeeze_idx += 1
                                try:
                                    idx = node2.inputs.index(link2)
                                    node2.inputs[idx] = squeeze_link
                                except ValueError:
                                    pass
            elif node.op == "Reshape":
                for node1 in link1.outputs:
                    if node1.op == "Gather":
                        node2s = [n for l in node1.outputs for n in l.outputs]
                        if any(n.op == "Abs" for n in node2s):
                            # same Gather -> Slice + Squeeze rewrite for the Reshape -> Gather -> Abs pattern
                            axis = node1.attrs.get("axis", 0)
                            index = node1.inputs[1].values
                            slice_link = Variable(
                                f"onnx::Slice_123{gather_idx}",
                                dtype=node1.outputs[0].dtype,
                                shape=[1] + node1.outputs[0].shape,
                            )
                            slice_node = Node(
                                op="Slice",
                                inputs=[
                                    node1.inputs[0],
                                    Constant(
                                        f"SliceStart_123{gather_idx}",
                                        np.array([index]),
                                    ),
                                    Constant(
                                        f"SliceEnd_123{gather_idx}",
                                        np.array([index + 1]),
                                    ),
                                    Constant(
                                        f"SliceAxis_123{gather_idx}",
                                        np.array([axis]),
                                    ),
                                ],
                                outputs=[slice_link],
                                name=f"Slice_123{gather_idx}",
                            )
                            graph.nodes.append(slice_node)
                            gather_idx += 1

                            squeeze_link = Variable(
                                f"onnx::Squeeze_123{squeeze_idx}",
                                dtype=node1.outputs[0].dtype,
                                shape=node1.outputs[0].shape,
                            )
                            squeeze_node = Node(
                                op="Squeeze",
                                inputs=[slice_link, Constant(f"SqueezeAxis_123{squeeze_idx}", np.array([0]))],
                                outputs=[squeeze_link],
                                name=f"Squeeze_123{squeeze_idx}",
                            )
                            graph.nodes.append(squeeze_node)
                            squeeze_idx += 1
                            for node2 in node2s:
                                node2.inputs[0] = squeeze_link
            elif node.op == "BatchNormalization":
                if node.attrs.get("training_mode") == 1:
                    # switch batch norm to inference mode and drop the extra training outputs
                    node.attrs["training_mode"] = 0
                    node.outputs = node.outputs[:1]

    graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
    graph.toposort()
    graph.fold_constants()
    updated = export_onnx(graph)
    onnx.save(updated, model_path)
    # infer_shapes_path(updated, check_type=True, strict_mode=False, data_prop=True)

    # for some reason, reloading the model is necessary to apply the correct shape
    proto = onnx.load(model_path)
    graph = import_onnx(proto)
    for node in graph.nodes:
        if node.op == "Slice":
            for link in node.outputs:
                if "Slice_123" in link.name and link.shape[0] == 3:
                    link.shape[0] = 1

    graph.cleanup(remove_unused_node_outputs=True, recurse_subgraphs=True, recurse_functions=True)
    graph.toposort()
    graph.fold_constants()
    updated = export_onnx(graph)
    onnx.save(updated, model_path)
    infer_shapes_path(model_path, check_type=True, strict_mode=True, data_prop=True)


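# Minimal usage sketch (hypothetical path; the real call site is ExportBase.to_onnx_static below).
# The function rewrites the ONNX file in place:
#   make_onnx_armnn_compatible("static/model.onnx")

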
def onnx_make_fixed(input_path: str, output_path: str, input_shape: tuple[int, ...]):
    simplified, success = onnxsim.simplify(input_path)
    if not success:
        raise RuntimeError(f"Failed to simplify {input_path}")
    try:
        onnx.save(simplified, output_path)
    except:
        # fall back to external data if the model cannot be saved as a single protobuf
        onnx.save(simplified, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
    infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)
    model = onnx.load_model(output_path)
    make_input_shape_fixed(model.graph, model.graph.input[0].name, input_shape)
    fix_output_shapes(model)
    try:
        onnx.save(model, output_path)
    except:
        onnx.save(model, output_path, save_as_external_data=True, all_tensors_to_one_file=False)
        onnx.save(model, output_path)
    infer_shapes_path(output_path, check_type=True, strict_mode=True, data_prop=True)


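# Usage sketch (hypothetical paths and shape; ExportBase.to_onnx_static drives this in practice):
#   onnx_make_fixed("model.onnx", "static/model.onnx", (1, 3, 224, 224))

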
class ExportBase:
    task: ClassVar[str]

    def __init__(
        self,
        name: str,
        input_shape: tuple[int, ...],
        pretrained: str | None = None,
        optimization_level: int = 5,
    ):
        super().__init__()
        self.name = name
        self.optimize = optimization_level
        self.input_shape = input_shape
        self.pretrained = pretrained
        self.cache_dir = os.path.join(os.environ["CACHE_DIR"], self.model_name)

    def download(self) -> str:
        model_path = os.path.join(self.cache_dir, self.task, "model.onnx")
        if not os.path.isfile(model_path):
            print(f"Downloading {self.model_name}...")
            snapshot_download(self.repo_name, cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
        return model_path

    def to_onnx_static(self) -> str:
        onnx_path_original = self.download()
        static_dir = os.path.join(self.cache_dir, self.task, "static")
        os.makedirs(static_dir, exist_ok=True)

        static_path = os.path.join(static_dir, "model.onnx")
        if not os.path.isfile(static_path):
            print(f"Making {self.model_name} ({self.task}) static")
            onnx_make_fixed(onnx_path_original, static_path, self.input_shape)
            make_onnx_armnn_compatible(static_path)
        static_model = onnx.load_model(static_path)
        self.inputs = [input_.name for input_ in static_model.graph.input]
        self.outputs = [output_.name for output_ in static_model.graph.output]
        return static_path

    def to_tflite(self, output_dir: str) -> tuple[str, str]:
        input_path = self.to_onnx_static()
        tflite_fp32 = os.path.join(output_dir, "model_float32.tflite")
        tflite_fp16 = os.path.join(output_dir, "model_float16.tflite")
        if not os.path.isfile(tflite_fp32) or not os.path.isfile(tflite_fp16):
            print(f"Exporting {self.model_name} ({self.task}) to TFLite (this might take a few minutes)")
            onnx2tf.convert(
                input_onnx_file_path=input_path,
                output_folder_path=output_dir,
                keep_shape_absolutely_input_names=self.inputs,
                verbosity="warn",
                copy_onnx_input_output_names_to_tflite=True,
                output_signaturedefs=True,
            )

        return tflite_fp32, tflite_fp16

    def to_armnn(self, output_dir: str) -> tuple[str, str]:
        output_dir = os.path.abspath(output_dir)
        tflite_model_dir = os.path.join(output_dir, "tflite")
        tflite_fp32, tflite_fp16 = self.to_tflite(tflite_model_dir)

        fp16_dir = os.path.join(output_dir, "fp16")
        os.makedirs(fp16_dir, exist_ok=True)
        armnn_fp32 = os.path.join(output_dir, "model.armnn")
        armnn_fp16 = os.path.join(fp16_dir, "model.armnn")

        # shared armnnconverter arguments; input/output tensor names come from the static ONNX model
        args = ["./armnnconverter", "-f", "tflite-binary"]
        args.append("-i")
        args.extend(self.inputs)
        args.append("-o")
        args.extend(self.outputs)

        fp32_args = args.copy()
        fp32_args.extend(["-m", tflite_fp32, "-p", armnn_fp32])

        print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp32 precision")
        try:
            print(subprocess.check_output(fp32_args, stderr=subprocess.STDOUT).decode())
        except subprocess.CalledProcessError as e:
            print(e.output.decode())
            try:
                rmtree(tflite_model_dir, ignore_errors=True)
            finally:
                raise e
        print(f"Finished exporting {self.model_name} ({self.task}) with fp32 precision")

        fp16_args = args.copy()
        fp16_args.extend(["-m", tflite_fp16, "-p", armnn_fp16])

        print(f"Exporting {self.model_name} ({self.task}) to ARM NN with fp16 precision")
        try:
            print(subprocess.check_output(fp16_args, stderr=subprocess.STDOUT).decode())
        except subprocess.CalledProcessError as e:
            print(e.output.decode())
            try:
                rmtree(tflite_model_dir, ignore_errors=True)
            finally:
                raise e
        print(f"Finished exporting {self.model_name} ({self.task}) with fp16 precision")

        return armnn_fp32, armnn_fp16

    @property
    def model_name(self) -> str:
        return f"{self.name}__{self.pretrained}" if self.pretrained else self.name

    @property
    def repo_name(self) -> str:
        return f"immich-app/{self.model_name}"


class ArcFace(ExportBase):
    task = "recognition"


class RetinaFace(ExportBase):
    task = "detection"


class OpenClipVisual(ExportBase):
    task = "visual"


class OpenClipTextual(ExportBase):
    task = "textual"


class MClipTextual(ExportBase):
    task = "textual"


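# Example (illustrative only; assumes CACHE_DIR is set and follows the output layout used in main() below):
#   model = OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai")
#   fp32_path, fp16_path = model.to_armnn(os.path.join("output", model.model_name, model.task))

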
def main() -> None:
    if platform.machine() not in ("x86_64", "AMD64"):
        raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
    hf_token = os.environ.get("HF_AUTH_TOKEN")
    if hf_token:
        login(token=hf_token)
    os.environ["LD_LIBRARY_PATH"] = "armnn"
    failed: list[Callable[[], ExportBase]] = [
        lambda: OpenClipVisual("ViT-H-14-378-quickgelu", (1, 3, 378, 378), pretrained="dfn5b"),  # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
        lambda: OpenClipVisual("ViT-H-14-quickgelu", (1, 3, 224, 224), pretrained="dfn5b"),  # flatbuffers: cannot grow buffer beyond 2 gigabytes (will probably work with fp16)
        lambda: OpenClipVisual("ViT-H-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b79k"),
        lambda: OpenClipTextual("ViT-H-14", (1, 77), pretrained="laion2b-s32b-b79k"),
        lambda: OpenClipVisual("ViT-g-14", (1, 3, 224, 224), pretrained="laion2b-s12b-b42k"),
        lambda: OpenClipTextual("ViT-g-14", (1, 77), pretrained="laion2b-s12b-b42k"),
        lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-16Plus", (1, 3, 240, 240)),
        lambda: OpenClipVisual("XLM-Roberta-Large-ViT-H-14", (1, 3, 224, 224), pretrained="frozen_laion5b_s13b_b90k"),
        lambda: MClipTextual("XLM-Roberta-Large-Vit-L-14", (1, 77)),  # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
        lambda: MClipTextual("XLM-Roberta-Large-Vit-B-16Plus", (1, 77)),  # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
        lambda: MClipTextual("LABSE-Vit-L-14", (1, 77)),  # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
        lambda: OpenClipTextual("XLM-Roberta-Large-ViT-H-14", (1, 77), pretrained="frozen_laion5b_s13b_b90k"),  # Expected normalized_shape to be at least 1-dimensional, i.e., containing at least one element, but got normalized_shape = []
    ]

    oom = [
        lambda: OpenClipVisual("nllb-clip-base-siglip", (1, 3, 384, 384), pretrained="v1"),
        lambda: OpenClipTextual("nllb-clip-base-siglip", (1, 77), pretrained="v1"),
        lambda: OpenClipVisual("nllb-clip-large-siglip", (1, 3, 384, 384), pretrained="v1"),
        lambda: OpenClipTextual("nllb-clip-large-siglip", (1, 77), pretrained="v1"),  # ERROR (tinynn.converter.base) Unsupported ops: aten::logical_not
        # lambda: OpenClipTextual("ViT-H-14-quickgelu", (1, 77), pretrained="dfn5b"),
        # lambda: OpenClipTextual("ViT-H-14-378-quickgelu", (1, 77), pretrained="dfn5b"),
        # lambda: OpenClipVisual("XLM-Roberta-Large-Vit-L-14", (1, 3, 224, 224)),
    ]

    succeeded: list[Callable[[], ExportBase]] = [
        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b_e16"),
        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b_e16"),
        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e31"),
        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e31"),
        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion400m_e32"),
        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion400m_e32"),
        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="laion2b-s34b-b79k"),
        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="laion2b-s34b-b79k"),
        # lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e31"),
        # lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e31"),
        # lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="laion400m_e32"),
        # lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="laion400m_e32"),
        # lambda: OpenClipVisual("ViT-B-16-plus-240", (1, 3, 240, 240), pretrained="laion400m_e31"),
        # lambda: OpenClipTextual("ViT-B-16-plus-240", (1, 77), pretrained="laion400m_e31"),
        # lambda: OpenClipVisual("ViT-B-32", (1, 3, 224, 224), pretrained="openai"),
        # lambda: OpenClipTextual("ViT-B-32", (1, 77), pretrained="openai"),
        # lambda: OpenClipVisual("ViT-B-16", (1, 3, 224, 224), pretrained="openai"),
        # lambda: OpenClipTextual("ViT-B-16", (1, 77), pretrained="openai"),
        # lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="openai"),
        # lambda: OpenClipTextual("RN50", (1, 77), pretrained="openai"),
        # lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="yfcc15m"),
        # lambda: OpenClipTextual("RN50", (1, 77), pretrained="yfcc15m"),
        # lambda: OpenClipVisual("RN50", (1, 3, 224, 224), pretrained="cc12m"),
        # lambda: OpenClipTextual("RN50", (1, 77), pretrained="cc12m"),
        # lambda: OpenClipVisual("XLM-Roberta-Large-Vit-B-32", (1, 3, 224, 224)),
        # lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="openai"),
        # lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="openai"),
        lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e31"),
        lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e31"),
        lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion400m_e32"),
        lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion400m_e32"),
        lambda: OpenClipVisual("ViT-L-14", (1, 3, 224, 224), pretrained="laion2b-s32b-b82k"),
        lambda: OpenClipTextual("ViT-L-14", (1, 77), pretrained="laion2b-s32b-b82k"),
        # lambda: OpenClipVisual("ViT-L-14-336", (1, 3, 336, 336), pretrained="openai"),
        # lambda: OpenClipTextual("ViT-L-14-336", (1, 77), pretrained="openai"),
        # lambda: ArcFace("buffalo_s", (1, 3, 112, 112), optimization_level=3),
        # lambda: RetinaFace("buffalo_s", (1, 3, 640, 640), optimization_level=3),
        # lambda: ArcFace("buffalo_m", (1, 3, 112, 112), optimization_level=3),
        # lambda: RetinaFace("buffalo_m", (1, 3, 640, 640), optimization_level=3),
        # lambda: ArcFace("buffalo_l", (1, 3, 112, 112), optimization_level=3),
        # lambda: RetinaFace("buffalo_l", (1, 3, 640, 640), optimization_level=3),
        # lambda: ArcFace("antelopev2", (1, 3, 112, 112), optimization_level=3),
        # lambda: RetinaFace("antelopev2", (1, 3, 640, 640), optimization_level=3),
    ]

    models: list[Callable[[], ExportBase]] = [*failed, *succeeded]
    for _model in succeeded:
        model = _model()
        try:
            model_dir = os.path.join("output", model.model_name)
            output_dir = os.path.join(model_dir, model.task)
            armnn_fp32, armnn_fp16 = model.to_armnn(output_dir)
            relative_fp32 = os.path.relpath(armnn_fp32, start=model_dir)
            relative_fp16 = os.path.relpath(armnn_fp16, start=model_dir)
            if hf_token and os.path.isfile(armnn_fp32):
                print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
                upload_file(path_or_fileobj=armnn_fp32, path_in_repo=relative_fp32, repo_id=model.repo_name)
                print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp32 precision")
            if hf_token and os.path.isfile(armnn_fp16):
                print(f"Uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
                upload_file(path_or_fileobj=armnn_fp16, path_in_repo=relative_fp16, repo_id=model.repo_name)
                print(f"Finished uploading {model.model_name} ({model.task}) ARM NN model with fp16 precision")
        except Exception as exc:
            print(f"Failed to export {model.model_name} ({model.task}): {exc}")
            raise exc


if __name__ == "__main__":
    main()
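
# How to run (a sketch based on the code above, not an official invocation; the filename is whatever this
# script is saved as): the constructors read CACHE_DIR, uploads need HF_AUTH_TOKEN, and to_armnn expects an
# ./armnnconverter binary with its libraries in ./armnn relative to the working directory.
#   CACHE_DIR=/tmp/model-cache HF_AUTH_TOKEN=hf_xxx python export_armnn.py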