diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 6457f45..7819447 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -974,7 +974,7 @@ def calculate_train( def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str: try: - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, weights_only=True) print("Original Checkpoint Keys:", checkpoint.keys()) ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)