mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 09:39:52 -08:00
fp16 inference only for cuda devices now
This commit is contained in:
@@ -96,7 +96,7 @@ class CFM(nn.Module):
|
||||
):
|
||||
self.eval()
|
||||
|
||||
if cond.device != torch.device('cpu'):
|
||||
if cond.device == torch.device('cuda'):
|
||||
cond = cond.half()
|
||||
|
||||
# raw wave
|
||||
|
||||
@@ -555,7 +555,7 @@ def repetition_found(text, length = 2, tolerance = 10):
|
||||
# load model checkpoint for inference
|
||||
|
||||
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
||||
if device != "cpu":
|
||||
if device == "cuda":
|
||||
model = model.half()
|
||||
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
|
||||
Reference in New Issue
Block a user