fix: default fp32 for ZLUDA (#578)

This commit is contained in:
SWivid
2024-12-05 11:50:50 +08:00
parent eea65de823
commit 7f7fd29675

View File

@@ -138,7 +138,11 @@ asr_pipe = None
def initialize_asr_pipeline(device: str = device, dtype=None):
if dtype is None:
dtype = (
-        torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        torch.float16
+        if "cuda" in device
+        and torch.cuda.get_device_properties(device).major >= 6
+        and not torch.cuda.get_device_name().endswith("[ZLUDA]")
+        else torch.float32
)
global asr_pipe
asr_pipe = pipeline(
@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
if dtype is None:
dtype = (
-        torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
+        torch.float16
+        if "cuda" in device
+        and torch.cuda.get_device_properties(device).major >= 6
+        and not torch.cuda.get_device_name().endswith("[ZLUDA]")
+        else torch.float32
)
model = model.to(dtype)