fp16 inference is now enabled only for CUDA devices

This commit is contained in:
SWivid
2024-10-21 03:34:28 +08:00
parent bd16a8c281
commit d3badb95cf
2 changed files with 2 additions and 2 deletions

View File

@@ -96,7 +96,7 @@ class CFM(nn.Module):
):
self.eval()
if cond.device != torch.device('cpu'):
if cond.device == torch.device('cuda'):
cond = cond.half()
# raw wave

View File

@@ -555,7 +555,7 @@ def repetition_found(text, length = 2, tolerance = 10):
# load model checkpoint for inference
def load_checkpoint(model, ckpt_path, device, use_ema = True):
if device != "cpu":
if device == "cuda":
model = model.half()
ckpt_type = ckpt_path.split(".")[-1]