disable fp16 for cpu device

This commit is contained in:
SWivid
2024-10-20 22:45:54 +08:00
parent 073092d0d3
commit b4f81425f3
2 changed files with 4 additions and 2 deletions

View File

@@ -96,7 +96,8 @@ class CFM(nn.Module):
):
self.eval()
cond = cond.half()
if cond.device != torch.device('cpu'):
cond = cond.half()
# raw wave

View File

@@ -555,7 +555,8 @@ def repetition_found(text, length = 2, tolerance = 10):
# load model checkpoint for inference
def load_checkpoint(model, ckpt_path, device, use_ema = True):
model = model.half()
if device != "cpu":
model = model.half()
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":