diff --git a/model/cfm.py b/model/cfm.py index f70b097..58f0a9b 100644 --- a/model/cfm.py +++ b/model/cfm.py @@ -96,7 +96,7 @@ class CFM(nn.Module): ): self.eval() - if cond.device != torch.device('cpu'): + if cond.device == torch.device('cuda'): cond = cond.half() # raw wave diff --git a/model/utils.py b/model/utils.py index ae64b0c..c898d91 100644 --- a/model/utils.py +++ b/model/utils.py @@ -555,7 +555,7 @@ def repetition_found(text, length = 2, tolerance = 10): # load model checkpoint for inference def load_checkpoint(model, ckpt_path, device, use_ema = True): - if device != "cpu": + if device == "cuda": model = model.half() ckpt_type = ckpt_path.split(".")[-1]