Mirror of https://github.com/immich-app/immich.git (synced 2025-12-11 07:11:05 -08:00)

Compare commits: sqlite-flu...ml/tflite (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | eb0f79b72e |  |
|  | 5f6ad9e239 |  |
```diff
@@ -20,6 +20,7 @@ dependencies:
 - torchvision
 - transformers==4.*
 - pip:
   - multilingual-clip
   - onnx-simplifier
+  - tensorflow==2.14.*
 category: main
```
machine-learning/export/models/tfclip.py (new file, +72 lines)
```python
import tempfile
from pathlib import Path

import tensorflow as tf
from transformers import TFCLIPModel

from .util import ModelType, get_model_path


class _CLIPWrapper(tf.Module):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = TFCLIPModel.from_pretrained(model_name)

    @tf.function()
    def encode_image(self, input_tensor):
        return self.model.get_image_features(input_tensor)

    @tf.function()
    def encode_text(self, input_tensor):
        return self.model.get_text_features(input_tensor)


# exported model signatures use batch size 2 for the following reasons:
# 1. ARM-NN cannot use dynamic batch sizes for complex models like CLIP ViT
# 2. batch size 1 creates a larger TF-Lite model that uses a lot (50%) more RAM
# 3. batch size 2 is ~50% faster on GPU than 1, while 4 (or larger) is not really faster
# 4. batch size >2 wastes more computation if only a single image is processed
BATCH_SIZE_IMAGE = 2
# on most small-scale systems there will only be one query at a time, so there is no sense in batching
BATCH_SIZE_TEXT = 1

SIGNATURE_TEXT = "encode_text"
SIGNATURE_IMAGE = "encode_image"


def to_tflite(
    model_name: str,
    output_path_image: Path | str | None,
    output_path_text: Path | str | None,
    context_length: int = 77,
) -> None:
    # export a temporary SavedModel, then convert each requested signature to TF-Lite
    with tempfile.TemporaryDirectory() as tmpdir:
        _export_temporary_tf_model(model_name, tmpdir, context_length)
        if output_path_image is not None:
            image_path = get_model_path(output_path_image, ModelType.TFLITE)
            _export_tflite_model(tmpdir, SIGNATURE_IMAGE, image_path.as_posix())
        if output_path_text is not None:
            text_path = get_model_path(output_path_text, ModelType.TFLITE)
            _export_tflite_model(tmpdir, SIGNATURE_TEXT, text_path.as_posix())


def _export_temporary_tf_model(model_name: str, tmp_path: str, context_length: int) -> None:
    wrapper = _CLIPWrapper(model_name)
    conf = wrapper.model.config.vision_config
    spec_visual = tf.TensorSpec(
        shape=(BATCH_SIZE_IMAGE, conf.num_channels, conf.image_size, conf.image_size), dtype=tf.float32
    )
    encode_image = wrapper.encode_image.get_concrete_function(spec_visual)
    spec_text = tf.TensorSpec(shape=(BATCH_SIZE_TEXT, context_length), dtype=tf.int32)
    encode_text = wrapper.encode_text.get_concrete_function(spec_text)
    signatures = {SIGNATURE_IMAGE: encode_image, SIGNATURE_TEXT: encode_text}
    tf.saved_model.save(wrapper, tmp_path, signatures)


def _export_tflite_model(tmp_path: str, signature: str, output_path: str) -> None:
    # convert a single signature of the SavedModel into a float16-optimized TF-Lite flatbuffer
    converter = tf.lite.TFLiteConverter.from_saved_model(tmp_path, signature_keys=[signature])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()
    with open(output_path, "wb") as f:
        f.write(tflite_model)
```
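For orientation, a minimal sketch of how the new exporter is invoked, assuming it runs from the export package like run.py does; the output directories here are hypothetical, and the checkpoint id is the one added to the export list below:

```python
from models import tfclip

# writes export/visual/model.tflite and export/textual/model.tflite
# via get_model_path(..., ModelType.TFLITE)
tfclip.to_tflite(
    "openai/clip-vit-base-patch32",
    output_path_image="export/visual",
    output_path_text="export/textual",
)
```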
```diff
@@ -1,12 +1,18 @@
 import json
+from enum import Enum
 from pathlib import Path
 from typing import Any
 
 
-def get_model_path(output_dir: Path | str) -> Path:
+class ModelType(Enum):
+    ONNX = "onnx"
+    TFLITE = "tflite"
+
+
+def get_model_path(output_dir: Path | str, model_type: ModelType = ModelType.ONNX) -> Path:
     output_dir = Path(output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
-    return output_dir / "model.onnx"
+    return output_dir / f"model.{model_type.value}"
 
 
 def save_config(config: Any, output_path: Path | str) -> None:
```
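A quick sketch of the extended helper's behavior (directory names are hypothetical, and the import path assumes the `models` package used by the export script); defaulting `model_type` to ONNX keeps every existing call site working unchanged:

```python
from models.util import ModelType, get_model_path

get_model_path("out/visual")                    # -> Path("out/visual/model.onnx"), the old behavior
get_model_path("out/visual", ModelType.TFLITE)  # -> Path("out/visual/model.tflite")
```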
```diff
@@ -4,9 +4,10 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 
 from huggingface_hub import create_repo, login, upload_folder
-from models import mclip, openclip
 from rich.progress import Progress
 
+from models import mclip, openclip, tfclip
+
 models = [
     "RN50::openai",
     "RN50::yfcc15m",
@@ -37,9 +38,10 @@ models = [
     "M-CLIP/XLM-Roberta-Large-Vit-B-32",
     "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
     "M-CLIP/XLM-Roberta-Large-Vit-L-14",
+    "openai/clip-vit-base-patch32",
 ]
 
-login(token=os.environ["HF_AUTH_TOKEN"])
+# login(token=os.environ["HF_AUTH_TOKEN"])
 
 with Progress() as progress:
     task1 = progress.add_task("[green]Exporting models...", total=len(models))
@@ -65,6 +67,8 @@ with Progress() as progress:
             textual_dir = tmpdir / model_name / "textual"
             if model.startswith("M-CLIP"):
                 mclip.to_onnx(model, visual_dir, textual_dir)
+            elif "/" in model:
+                tfclip.to_tflite(model, visual_dir.as_posix(), textual_dir.as_posix())
             else:
                 name, _, pretrained = model_name.partition("__")
                 openclip.to_onnx(openclip.OpenCLIPModelConfig(name, pretrained), visual_dir, textual_dir)
```
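The new `elif` routes any remaining model name containing a "/" to the TF-Lite exporter. Ordering matters here, since M-CLIP ids also contain a "/": the `startswith("M-CLIP")` branch has to run first. A standalone sketch of the same rule (the function name is made up for illustration):

```python
def pick_exporter(model: str) -> str:
    if model.startswith("M-CLIP"):  # e.g. "M-CLIP/XLM-Roberta-Large-Vit-L-14" -> mclip (ONNX)
        return "mclip"
    elif "/" in model:              # HF hub ids, e.g. "openai/clip-vit-base-patch32" -> tfclip (TF-Lite)
        return "tfclip"
    else:                           # OpenCLIP ids, e.g. "RN50::openai" -> openclip (ONNX)
        return "openclip"
```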
machine-learning/poetry.lock (generated, 36 lines changed)
```diff
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "aiocache"
@@ -3882,6 +3882,30 @@ files = [
 [package.dependencies]
 mpmath = ">=0.19"
 
+[[package]]
+name = "tflite-runtime"
+version = "2.14.0"
+description = "TensorFlow Lite is for mobile and embedded devices."
+optional = false
+python-versions = "*"
+files = [
+    {file = "tflite_runtime-2.14.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:bb11df4283e281cd609c621ac9470ad0cb5674408593272d7593a2c6bde8a808"},
+    {file = "tflite_runtime-2.14.0-cp310-cp310-manylinux_2_34_aarch64.whl", hash = "sha256:d38c6885f5e9673c11a61ccec5cad7c032ab97340718d26b17794137f398b780"},
+    {file = "tflite_runtime-2.14.0-cp310-cp310-manylinux_2_34_armv7l.whl", hash = "sha256:7fe33f763263d1ff2733a09945a7547ab063d8bc311fd2a1be8144d850016ad3"},
+    {file = "tflite_runtime-2.14.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:195ab752e7e57329a68e54dd3dd5439fad888b9bff1be0f0dc042a3237a90e4d"},
+    {file = "tflite_runtime-2.14.0-cp311-cp311-manylinux_2_34_aarch64.whl", hash = "sha256:ce9fa5d770a9725c746dcbf6f59f3178233b3759f09982e8b2db8d2234c333b0"},
+    {file = "tflite_runtime-2.14.0-cp311-cp311-manylinux_2_34_armv7l.whl", hash = "sha256:c4e66a74165b18089c86788400af19fa551768ac782d231a9beae2f6434f7949"},
+    {file = "tflite_runtime-2.14.0-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:9f965054467f7890e678943858c6ac76a5197b17f61b48dcbaaba0af41d541a7"},
+    {file = "tflite_runtime-2.14.0-cp38-cp38-manylinux_2_34_aarch64.whl", hash = "sha256:437167fe3d8b12f50f5d694da8f45d268ab84a495e24c3dd810e02e1012125de"},
+    {file = "tflite_runtime-2.14.0-cp38-cp38-manylinux_2_34_armv7l.whl", hash = "sha256:79d8e17f68cc940df7e68a177b22dda60fcffba195fb9dd908d03724d65fd118"},
+    {file = "tflite_runtime-2.14.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4aa740210a0fd9e4db4a46e9778914846b136e161525681b41575ca4896158fb"},
+    {file = "tflite_runtime-2.14.0-cp39-cp39-manylinux_2_34_aarch64.whl", hash = "sha256:be198b7dc4401204be54a15884d9e336389790eb707439524540f5a9329fdd02"},
+    {file = "tflite_runtime-2.14.0-cp39-cp39-manylinux_2_34_armv7l.whl", hash = "sha256:eca7672adca32727bbf5c0f1caf398fc17bbe222f2a684c7a2caea6fc6767203"},
+]
+
+[package.dependencies]
+numpy = ">=1.23.2"
+
 [[package]]
 name = "threadpoolctl"
 version = "3.2.0"
@@ -4025,6 +4049,14 @@ dev = ["tokenizers[testing]"]
 docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"]
 testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
 
+[[package]]
+name = "torch"
+version = "2.0.1"
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+optional = false
+python-versions = "*"
+files = []
+
 [[package]]
 name = "torch"
 version = "2.1.0"
@@ -4772,4 +4804,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "bba5f87aa67bc1d2283a9f4b471ef78e572337f22413870d324e908014410d53"
+content-hash = "56614afdeeeec3b7f0b786771a8fcc126761c882b1033664056042833767e521"
```
```diff
@@ -29,6 +29,7 @@ python-multipart = "^0.0.6"
 orjson = "^3.9.5"
 safetensors = "0.3.2"
 gunicorn = "^21.1.0"
+tflite-runtime = "^2.14.0"
 
 [tool.poetry.group.dev.dependencies]
 mypy = "^1.3.0"
```
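With `tflite-runtime` as a dependency, the exported models can be served without a full TensorFlow install. A minimal inference sketch, assuming a hypothetical path to an exported visual model; note the image signature was exported with a fixed batch of 2, so a single image is repeated to fill the batch:

```python
import numpy as np
from tflite_runtime.interpreter import Interpreter

interpreter = Interpreter(model_path="export/visual/model.tflite")  # hypothetical path
encode_image = interpreter.get_signature_runner("encode_image")

# one preprocessed image in CLIP's NCHW layout (3x224x224 for ViT-B/32),
# repeated to match the fixed BATCH_SIZE_IMAGE = 2 used at export time
image = np.zeros((1, 3, 224, 224), dtype=np.float32)
outputs = encode_image(input_tensor=np.repeat(image, 2, axis=0))  # kwarg matches the exported arg name

embedding = next(iter(outputs.values()))[0]  # first row holds the real image's features
```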