mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-07 02:24:51 -08:00
fix. default fp32 for ZLUDA #578
This commit is contained in:
@@ -138,7 +138,11 @@ asr_pipe = None
|
||||
def initialize_asr_pipeline(device: str = device, dtype=None):
|
||||
if dtype is None:
|
||||
dtype = (
|
||||
torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 6
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
global asr_pipe
|
||||
asr_pipe = pipeline(
|
||||
@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
|
||||
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
||||
if dtype is None:
|
||||
dtype = (
|
||||
torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
|
||||
torch.float16
|
||||
if "cuda" in device
|
||||
and torch.cuda.get_device_properties(device).major >= 6
|
||||
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
||||
else torch.float32
|
||||
)
|
||||
model = model.to(dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user