From 7f7fd29675d8fbc31c3333d8ed916d1267f17854 Mon Sep 17 00:00:00 2001 From: SWivid Date: Thu, 5 Dec 2024 11:50:50 +0800 Subject: [PATCH] fix. default fp32 for ZLUDA #578 --- src/f5_tts/infer/utils_infer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 42ffe57..9f17340 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -138,7 +138,11 @@ asr_pipe = None def initialize_asr_pipeline(device: str = device, dtype=None): if dtype is None: dtype = ( - torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32 + torch.float16 + if "cuda" in device + and torch.cuda.get_device_properties(device).major >= 6 + and not torch.cuda.get_device_name().endswith("[ZLUDA]") + else torch.float32 ) global asr_pipe asr_pipe = pipeline( @@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None): def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True): if dtype is None: dtype = ( - torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32 + torch.float16 + if "cuda" in device + and torch.cuda.get_device_properties(device).major >= 6 + and not torch.cuda.get_device_name().endswith("[ZLUDA]") + else torch.float32 ) model = model.to(dtype)