From b4f81425f38da7d74fa80ddaab3f96c29696349b Mon Sep 17 00:00:00 2001 From: SWivid Date: Sun, 20 Oct 2024 22:45:54 +0800 Subject: [PATCH] disable fp16 for cpu device --- model/cfm.py | 3 ++- model/utils.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/model/cfm.py b/model/cfm.py index fdf4293..f70b097 100644 --- a/model/cfm.py +++ b/model/cfm.py @@ -96,7 +96,8 @@ class CFM(nn.Module): ): self.eval() - cond = cond.half() + if cond.device != torch.device('cpu'): + cond = cond.half() # raw wave diff --git a/model/utils.py b/model/utils.py index e6494a2..ae64b0c 100644 --- a/model/utils.py +++ b/model/utils.py @@ -555,7 +555,8 @@ def repetition_found(text, length = 2, tolerance = 10): # load model checkpoint for inference def load_checkpoint(model, ckpt_path, device, use_ema = True): - model = model.half() + if device != "cpu": + model = model.half() ckpt_type = ckpt_path.split(".")[-1] if ckpt_type == "safetensors":