default to weights_only=True for safer loading

This commit is contained in:
Jarod Mica
2024-10-15 00:37:46 -07:00
parent 49b465f5d8
commit 31e5051d51
4 changed files with 6 additions and 5 deletions

View File

@@ -140,7 +140,7 @@ class Trainer:
else:
latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
if self.is_main:
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])

View File

@@ -510,7 +510,7 @@ def run_sim(args):
device = f"cuda:{rank}"
model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict['model'], strict=False)
use_gpu=True if torch.cuda.is_available() else False
@@ -566,7 +566,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema = True):
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path, device=device)
else:
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
if use_ema == True:
ema_model = EMA(model, include_online_model = False).to(device)

View File

@@ -127,7 +127,7 @@ local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:

View File

@@ -85,8 +85,9 @@ local = False
if local:
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", weights_only=True, map_location=device)
vocos.load_state_dict(state_dict)
vocos.eval()
else:
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")