mirror of
https://github.com/immich-app/immich.git
synced 2026-06-12 19:11:52 -07:00
fix(ml): stabilize MIGraphX inference (#28444)
* fix: stabilize ROCm MIGraphX inference

  Serialize MIGraphX session runs so lazy compiles cannot overlap within a worker. Use a fixed face-recognition batch size for MIGraphX to avoid compiling a new program for each detected face count.

* fix(ml): increase ROCm worker timeout
* fix(ml): narrow MIGraphX compile locking
* docs: format environment variables table
* docs: apply prettier to environment variables table
This commit is contained in:
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from socket import socket
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
@@ -42,6 +42,10 @@ class MaxBatchSize(BaseModel):
|
||||
ocr: int | None = None
|
||||
|
||||
|
||||
def default_worker_timeout() -> int:
    """Return the default value for the ``worker_timeout`` setting.

    The default is 900 when the ``DEVICE`` environment variable is exactly
    ``"rocm"`` and 300 otherwise (including when ``DEVICE`` is unset).

    Returns:
        The default worker timeout.
    """
    # The environment is read on every call, so a default_factory using this
    # function sees the value of DEVICE at the time the settings are built.
    return 900 if os.environ.get("DEVICE") == "rocm" else 300
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="MACHINE_LEARNING_",
|
||||
@@ -54,7 +58,7 @@ class Settings(BaseSettings):
|
||||
model_ttl: int = 300
|
||||
model_ttl_poll_s: int = 10
|
||||
workers: int = 1
|
||||
worker_timeout: int = 300
|
||||
worker_timeout: int = Field(default_factory=default_worker_timeout)
|
||||
http_keepalive_timeout_s: int = 2
|
||||
test_full: bool = False
|
||||
request_threads: int = os.cpu_count() or 4
|
||||
|
||||
@@ -89,4 +89,10 @@ class FaceRecognizer(InferenceModel):
|
||||
@property
def _batch_size_default(self) -> int | None:
    """Default batch size for this model.

    Returns:
        ``None`` when the model format is ONNX and neither the MIGraphX nor
        the OpenVINO execution provider is available; ``1`` in every other
        case.
    """
    providers = ort.get_available_providers()
    # The superseded single-line return that only checked OpenVINO has been
    # dropped: left in place ahead of this check it made the MIGraphX
    # condition unreachable.
    if (
        self.model_format == ModelFormat.ONNX
        and "MIGraphXExecutionProvider" not in providers
        and "OpenVINOExecutionProvider" not in providers
    ):
        return None
    return 1
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -12,6 +13,37 @@ from immich_ml.schemas import ModelPrecision, SessionNode
|
||||
|
||||
from ..config import log, settings
|
||||
|
||||
# One (input name, dtype string, shape) entry per model input.
MigraphxInputSignature = tuple[tuple[str, str, tuple[int, ...]], ...]

# Guards every read and write of the two registries below.
_migraphx_registry_lock = Lock()
# One lock per model key, created on first request.
_migraphx_model_locks: dict[str, Lock] = {}
# (model key, input signature) pairs that have been recorded as compiled.
_migraphx_compiled_inputs: set[tuple[str, MigraphxInputSignature]] = set()
|
||||
|
||||
|
||||
def _migraphx_get_model_lock(model_key: str) -> Lock:
    """Return the lock associated with ``model_key``, creating it if needed.

    Args:
        model_key: Key identifying the model in the lock registry.

    Returns:
        The same ``Lock`` instance for every call with the same key.
    """
    # Lookup and insertion happen under the registry lock so two callers
    # asking for the same new key cannot each create their own lock.
    with _migraphx_registry_lock:
        lock = _migraphx_model_locks.get(model_key)
        if lock is None:
            lock = Lock()
            _migraphx_model_locks[model_key] = lock
        return lock
|
||||
|
||||
|
||||
def _migraphx_has_compiled_input(key: tuple[str, MigraphxInputSignature]) -> bool:
    """Report whether ``key`` has been recorded as compiled.

    Args:
        key: A ``(model key, input signature)`` pair.

    Returns:
        ``True`` if the pair is present in the compiled-input registry.
    """
    with _migraphx_registry_lock:
        return key in _migraphx_compiled_inputs
|
||||
|
||||
|
||||
def _migraphx_mark_compiled_input(key: tuple[str, MigraphxInputSignature]) -> None:
    """Record ``key`` in the compiled-input registry.

    Args:
        key: A ``(model key, input signature)`` pair to remember.
    """
    with _migraphx_registry_lock:
        _migraphx_compiled_inputs.add(key)
|
||||
|
||||
|
||||
def _migraphx_input_signature(
    input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
) -> MigraphxInputSignature:
    """Build a hashable description of an input feed.

    Args:
        input_feed: Mapping from input name to the array fed for that input.

    Returns:
        A tuple with one ``(name, dtype string, shape tuple)`` entry per
        input, ordered by input name so that dict insertion order does not
        affect the result. An empty feed yields an empty tuple.
    """
    # Sort on the name alone: ordering must never fall through to comparing
    # the array values themselves.
    ordered = sorted(input_feed.items(), key=lambda item: item[0])
    return tuple((name, str(value.dtype), tuple(value.shape)) for name, value in ordered)
|
||||
|
||||
|
||||
class OrtSession:
|
||||
session: ort.InferenceSession
|
||||
@@ -48,7 +80,21 @@ class OrtSession:
|
||||
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
|
||||
run_options: Any = None,
|
||||
) -> list[NDArray[np.float32]]:
|
||||
outputs: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
|
||||
if "MIGraphXExecutionProvider" in self.providers:
|
||||
model_key = self.model_path.resolve().as_posix()
|
||||
input_key = (model_key, _migraphx_input_signature(input_feed))
|
||||
if not _migraphx_has_compiled_input(input_key):
|
||||
model_lock = _migraphx_get_model_lock(model_key)
|
||||
with model_lock:
|
||||
if not _migraphx_has_compiled_input(input_key):
|
||||
outputs: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
|
||||
_migraphx_mark_compiled_input(input_key)
|
||||
return outputs
|
||||
|
||||
outputs = self.session.run(output_names, input_feed, run_options)
|
||||
return outputs
|
||||
|
||||
outputs = self.session.run(output_names, input_feed, run_options)
|
||||
return outputs
|
||||
|
||||
@property
|
||||
|
||||
@@ -35,7 +35,37 @@ from immich_ml.sessions.ort import OrtSession
|
||||
from immich_ml.sessions.rknn import RknnSession, run_inference
|
||||
|
||||
|
||||
class FakeLock:
    """Context-manager stand-in for a lock that records enter/exit calls.

    ``enter`` and ``exit`` are mocks, so tests can assert how many times the
    lock was entered and left.
    """

    def __init__(self) -> None:
        self.enter = mock.Mock()
        self.exit = mock.Mock()

    def __enter__(self) -> None:
        self.enter()

    def __exit__(self, *args: object) -> None:
        # Forward the exception triple so tests can inspect it if needed.
        self.exit(*args)
|
||||
|
||||
|
||||
class TestBase:
|
||||
def test_sets_default_worker_timeout(self, monkeypatch: MonkeyPatch) -> None:
    """With DEVICE and the explicit override both unset, the timeout is 300."""
    monkeypatch.delenv("DEVICE", raising=False)
    monkeypatch.delenv("MACHINE_LEARNING_WORKER_TIMEOUT", raising=False)

    assert Settings().worker_timeout == 300
|
||||
|
||||
def test_sets_rocm_default_worker_timeout(self, monkeypatch: MonkeyPatch) -> None:
    """With DEVICE=rocm and no explicit override, the timeout is 900."""
    monkeypatch.setenv("DEVICE", "rocm")
    monkeypatch.delenv("MACHINE_LEARNING_WORKER_TIMEOUT", raising=False)

    assert Settings().worker_timeout == 900
|
||||
|
||||
def test_worker_timeout_env_override(self, monkeypatch: MonkeyPatch) -> None:
    """An explicit MACHINE_LEARNING_WORKER_TIMEOUT wins over the ROCm default."""
    monkeypatch.setenv("DEVICE", "rocm")
    monkeypatch.setenv("MACHINE_LEARNING_WORKER_TIMEOUT", "1200")

    assert Settings().worker_timeout == 1200
|
||||
|
||||
def test_sets_default_cache_dir(self) -> None:
|
||||
encoder = OpenClipTextualEncoder("ViT-B-32__openai")
|
||||
|
||||
@@ -413,6 +443,52 @@ class TestOrtSession:
|
||||
|
||||
assert sess_options is session.sess_options
|
||||
|
||||
def test_serializes_rocm_first_run_for_new_input_signature(self, mocker: MockerFixture) -> None:
    """Only the first run for a given input signature takes the model lock."""
    lock = FakeLock()
    get_model_lock = mocker.patch("immich_ml.sessions.ort._migraphx_get_model_lock", return_value=lock)
    # Start from an empty compiled-input registry so the first run is "new".
    mocker.patch("immich_ml.sessions.ort._migraphx_compiled_inputs", set())
    mocker.patch("immich_ml.sessions.ort.Path.mkdir")
    session = OrtSession("/cache/ViT-B-32__openai/model.onnx", providers=["MIGraphXExecutionProvider"])
    input_feed = {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}

    session.run(None, input_feed)
    session.run(None, input_feed)

    # The second run reuses the same signature, so the lock is entered once.
    lock.enter.assert_called_once()
    lock.exit.assert_called_once()
    get_model_lock.assert_called_once()
    session.session.run.assert_has_calls([mock.call(None, input_feed, None), mock.call(None, input_feed, None)])
|
||||
|
||||
def test_serializes_rocm_run_for_each_new_input_signature(self, mocker: MockerFixture) -> None:
    """Each previously unseen input signature takes the model lock once."""
    lock = FakeLock()
    mocker.patch("immich_ml.sessions.ort._migraphx_get_model_lock", return_value=lock)
    # Start from an empty compiled-input registry so both shapes are "new".
    mocker.patch("immich_ml.sessions.ort._migraphx_compiled_inputs", set())
    mocker.patch("immich_ml.sessions.ort.Path.mkdir")
    session = OrtSession("/cache/ViT-B-32__openai/model.onnx", providers=["MIGraphXExecutionProvider"])
    input_feed = {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}
    new_shape_input_feed = {"input": np.random.rand(2, 3, 224, 224).astype(np.float32)}

    session.run(None, input_feed)
    session.run(None, new_shape_input_feed)

    # Two distinct shapes mean two distinct signatures and two locked runs.
    assert lock.enter.call_count == 2
    assert lock.exit.call_count == 2
    session.session.run.assert_has_calls(
        [mock.call(None, input_feed, None), mock.call(None, new_shape_input_feed, None)]
    )
|
||||
|
||||
def test_does_not_serialize_non_rocm_run(self, mocker: MockerFixture) -> None:
    """A session without the MIGraphX provider never touches the model lock."""
    lock = FakeLock()
    get_model_lock = mocker.patch("immich_ml.sessions.ort._migraphx_get_model_lock", return_value=lock)
    session = OrtSession("/cache/ViT-B-32__openai/model.onnx", providers=["CPUExecutionProvider"])
    input_feed = {"input": np.random.rand(1, 3, 224, 224).astype(np.float32)}

    session.run(None, input_feed)

    get_model_lock.assert_not_called()
    lock.enter.assert_not_called()
    session.session.run.assert_called_once_with(None, input_feed, None)
|
||||
|
||||
|
||||
class TestAnnSession:
|
||||
def test_creates_ann_session(self, ann_session: mock.Mock, info: mock.Mock) -> None:
|
||||
@@ -883,6 +959,34 @@ class TestFaceRecognition:
|
||||
onnx.load.assert_not_called()
|
||||
onnx.save.assert_not_called()
|
||||
|
||||
def test_recognition_does_not_add_batch_axis_for_migraphx(
    self, ort_session: mock.Mock, path: mock.Mock, mocker: MockerFixture
) -> None:
    """With MIGraphX available, loading keeps batch size 1 and leaves the model file alone."""
    onnx = mocker.patch("immich_ml.models.facial_recognition.recognition.onnx", autospec=True)
    update_dims = mocker.patch(
        "immich_ml.models.facial_recognition.recognition.update_inputs_outputs_dims", autospec=True
    )
    mocker.patch("immich_ml.models.base.InferenceModel.download")
    mocker.patch("immich_ml.models.facial_recognition.recognition.ArcFaceONNX")
    # Report MIGraphX as an available provider for the recognizer.
    mocker.patch(
        "immich_ml.models.facial_recognition.recognition.ort.get_available_providers",
        return_value=["MIGraphXExecutionProvider", "CPUExecutionProvider"],
    )
    path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"

    inputs = [SimpleNamespace(name="input.1", shape=(1, 3, 224, 224))]
    outputs = [SimpleNamespace(name="output.1", shape=(1, 800))]
    ort_session.return_value.get_inputs.return_value = inputs
    ort_session.return_value.get_outputs.return_value = outputs

    face_recognizer = FaceRecognizer("buffalo_s", cache_dir=path)
    face_recognizer.load()

    assert face_recognizer.batch_size == 1
    # No dimension rewrite and no ONNX load/save should have happened.
    update_dims.assert_not_called()
    onnx.load.assert_not_called()
    onnx.save.assert_not_called()
|
||||
|
||||
def test_set_custom_max_batch_size(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(settings, "max_batch_size", MaxBatchSize(facial_recognition=2))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user