From 2d27d2c1b2bfbc65548ee5aedbea3f4dfdbe3a13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hasan=20Can=20Solako=C4=9Flu?=
Date: Fri, 17 Jan 2025 19:35:19 +0300
Subject: [PATCH] Exclude pretrained models from the checkpoint rotation logic

---
 src/f5_tts/model/trainer.py         | 28 +++++++++++++++++++++++-----
 src/f5_tts/train/finetune_cli.py    |  3 ++-
 src/f5_tts/train/finetune_gradio.py |  4 +++-
 3 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index d397bab..d94c5fc 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -160,10 +160,14 @@ class Trainer:
                     return
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
-                        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
                     ]
                     checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                     while len(checkpoints) > self.keep_last_n_checkpoints:
@@ -183,10 +187,24 @@ class Trainer:
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
-            latest_checkpoint = sorted(
-                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
-                key=lambda x: int("".join(filter(str.isdigit, x))),
-            )[-1]
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+            ]
+
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
+
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py
index a148d7b..255141b 100644
--- a/src/f5_tts/train/finetune_cli.py
+++ b/src/f5_tts/train/finetune_cli.py
@@ -111,7 +111,8 @@ def main():
         if not os.path.isdir(checkpoint_path):
             os.makedirs(checkpoint_path, exist_ok=True)
 
-        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+        # Change: Add 'pretrained_' prefix to copied model
+        file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
         if not os.path.isfile(file_checkpoint):
             shutil.copy2(ckpt_path, file_checkpoint)
             print("copy checkpoint for finetune")
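
The resume priority that the second trainer.py hunk above encodes can be read as a small standalone helper: model_last.pt wins if it exists, otherwise the newest numbered model_<update>.pt is taken, and only when neither is present does the copied pretrained_*.pt base model get loaded. The sketch below is illustrative only and not code from this patch; the helper name pick_resume_checkpoint and the FileNotFoundError fallback are assumptions made for the example.

    import os

    def pick_resume_checkpoint(checkpoint_path: str) -> str:
        # Illustrative mirror of the selection order in Trainer.load_checkpoint.
        files = os.listdir(checkpoint_path)
        # 1) The rolling "last" checkpoint always wins.
        if "model_last.pt" in files:
            return "model_last.pt"
        # 2) Otherwise take the newest intermediate checkpoint, model_<update>.pt.
        training = [f for f in files if f.startswith("model_") and f.endswith(".pt")]
        if training:
            return max(training, key=lambda x: int("".join(filter(str.isdigit, x))))
        # 3) Fall back to the copied base model, e.g. pretrained_model_1200000.pt.
        pretrained = [f for f in files if f.startswith("pretrained_") and f.endswith(".pt")]
        if pretrained:
            return pretrained[0]
        raise FileNotFoundError(f"no .pt checkpoint found in {checkpoint_path}")
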
diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index e27ef3a..875a366 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -1068,7 +1068,9 @@ def vocab_extend(project_name, symbols, model_type):
     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
     os.makedirs(new_ckpt_path, exist_ok=True)
-    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
+
+    # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
+    new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
 
     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
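
With both entry points copying the base model under a pretrained_ prefix (pretrained_<original name> from finetune_cli.py, pretrained_model_1200000.pt from the Gradio vocab-extend path), the copy no longer matches the model_<update>.pt pattern that the rotation pass deletes. The throwaway-directory demo below is a sketch of that end-to-end behavior using the same filters as the trainer.py hunk; the file name source_base.pt and the keep_last_n_checkpoints value are invented for the demo and are not part of the patch.

    import os
    import shutil
    import tempfile

    src_dir = tempfile.mkdtemp()   # stand-in location of the downloaded base model
    ckpt_dir = tempfile.mkdtemp()  # stand-in for the run's checkpoint directory

    # finetune_cli.py-style copy of the base model into the checkpoint directory, with the prefix.
    base_ckpt = os.path.join(src_dir, "source_base.pt")
    open(base_ckpt, "wb").close()
    shutil.copy2(base_ckpt, os.path.join(ckpt_dir, "pretrained_" + os.path.basename(base_ckpt)))

    # Simulate a few saved training checkpoints plus the rolling last checkpoint.
    for name in ("model_1000.pt", "model_2000.pt", "model_3000.pt", "model_last.pt"):
        open(os.path.join(ckpt_dir, name), "wb").close()

    # Rotation pass with the same filters as the trainer.py hunk, keeping the newest 2 numbered checkpoints.
    keep_last_n_checkpoints = 2
    checkpoints = [
        f
        for f in os.listdir(ckpt_dir)
        if f.startswith("model_")
        and not f.startswith("pretrained_")
        and f.endswith(".pt")
        and f != "model_last.pt"
    ]
    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    while len(checkpoints) > keep_last_n_checkpoints:
        os.remove(os.path.join(ckpt_dir, checkpoints.pop(0)))

    # Only model_1000.pt is removed; model_last.pt and the pretrained_ copy are untouched.
    print(sorted(os.listdir(ckpt_dir)))
    shutil.rmtree(src_dir)
    shutil.rmtree(ckpt_dir)
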