diff --git a/model/trainer.py b/model/trainer.py index fbd15a9..676a7d0 100644 --- a/model/trainer.py +++ b/model/trainer.py @@ -140,7 +140,7 @@ class Trainer: else: latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1] # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ - checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu") + checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu") if self.is_main: self.ema_model.load_state_dict(checkpoint['ema_model_state_dict']) diff --git a/model/utils.py b/model/utils.py index 818df12..45a338d 100644 --- a/model/utils.py +++ b/model/utils.py @@ -510,7 +510,7 @@ def run_sim(args): device = f"cuda:{rank}" model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None) - state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage) + state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage) model.load_state_dict(state_dict['model'], strict=False) use_gpu=True if torch.cuda.is_available() else False @@ -566,7 +566,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True): from safetensors.torch import load_file checkpoint = load_file(ckpt_path, device=device) else: - checkpoint = torch.load(ckpt_path, map_location=device) + checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device) if use_ema == True: ema_model = EMA(model, include_online_model = False).to(device) diff --git a/scripts/eval_infer_batch.py b/scripts/eval_infer_batch.py index 726eb93..d13cc20 100644 --- a/scripts/eval_infer_batch.py +++ b/scripts/eval_infer_batch.py @@ -127,7 +127,7 @@ local = False if local: vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) + state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device) vocos.load_state_dict(state_dict) vocos.eval() else: diff --git a/speech_edit.py b/speech_edit.py index b83c335..991eac4 100644 --- a/speech_edit.py +++ b/speech_edit.py @@ -85,8 +85,9 @@ local = False if local: vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz" vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device) + state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device) vocos.load_state_dict(state_dict) + vocos.eval() else: vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")