diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 42ffe57..9f17340 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -138,7 +138,11 @@ asr_pipe = None
 def initialize_asr_pipeline(device: str = device, dtype=None):
     if dtype is None:
         dtype = (
-            torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+            torch.float16
+            if "cuda" in device
+            and torch.cuda.get_device_properties(device).major >= 6
+            and not torch.cuda.get_device_name(device).endswith("[ZLUDA]")
+            else torch.float32
         )
     global asr_pipe
     asr_pipe = pipeline(
@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
 def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
     if dtype is None:
         dtype = (
-            torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+            torch.float16
+            if "cuda" in device
+            and torch.cuda.get_device_properties(device).major >= 6
+            and not torch.cuda.get_device_name(device).endswith("[ZLUDA]")
+            else torch.float32
         )
     model = model.to(dtype)