diff --git a/src/f5_tts/configs/E2TTS_Base_train.yaml b/src/f5_tts/configs/E2TTS_Base_train.yaml index 4bdf2b6..5874a7c 100644 --- a/src/f5_tts/configs/E2TTS_Base_train.yaml +++ b/src/f5_tts/configs/E2TTS_Base_train.yaml @@ -41,4 +41,5 @@ ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates last_per_updates: 5000 # save last checkpoint per updates + keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/configs/E2TTS_Small_train.yaml b/src/f5_tts/configs/E2TTS_Small_train.yaml index c95e26b..a14bf41 100644 --- a/src/f5_tts/configs/E2TTS_Small_train.yaml +++ b/src/f5_tts/configs/E2TTS_Small_train.yaml @@ -41,4 +41,5 @@ ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates last_per_updates: 5000 # save last checkpoint per updates + keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/configs/F5TTS_Base_train.yaml b/src/f5_tts/configs/F5TTS_Base_train.yaml index e3bcbe3..f4a6a00 100644 --- a/src/f5_tts/configs/F5TTS_Base_train.yaml +++ b/src/f5_tts/configs/F5TTS_Base_train.yaml @@ -44,4 +44,5 @@ ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates last_per_updates: 5000 # save last checkpoint per updates + keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: 
ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/configs/F5TTS_Small_train.yaml b/src/f5_tts/configs/F5TTS_Small_train.yaml index 460a2e4..e2ad2cc 100644 --- a/src/f5_tts/configs/F5TTS_Small_train.yaml +++ b/src/f5_tts/configs/F5TTS_Small_train.yaml @@ -44,4 +44,5 @@ ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates last_per_updates: 5000 # save last checkpoint per updates + keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index 7295571..f96fe9c 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -50,7 +50,17 @@ class Trainer: mel_spec_type: str = "vocos", # "vocos" | "bigvgan" is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path + keep_last_n_checkpoints: int + | None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints ): + # Validate keep_last_n_checkpoints + if not isinstance(keep_last_n_checkpoints, int): + raise ValueError("keep_last_n_checkpoints must be an integer") + if keep_last_n_checkpoints < -1: + raise ValueError( + "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer" + ) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) if logger == "wandb" and not wandb.api.api_key: @@ -134,6 +144,8 @@ class Trainer: self.optimizer = AdamW(model.parameters(), lr=learning_rate) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + self.keep_last_n_checkpoints = keep_last_n_checkpoints if 
keep_last_n_checkpoints is not None else None + @property def is_main(self): return self.accelerator.is_main_process @@ -154,7 +166,26 @@ class Trainer: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") print(f"Saved last checkpoint at update {update}") else: + # Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0 + if self.keep_last_n_checkpoints == 0: + return + self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt") + # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive + if self.keep_last_n_checkpoints > 0: + # Get all checkpoint files except model_last.pt + checkpoints = [ + f + for f in os.listdir(self.checkpoint_path) + if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt" + ] + # Sort by step number + checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0])) + # Remove old checkpoints if we have more than keep_last_n_checkpoints + while len(checkpoints) > self.keep_last_n_checkpoints: + oldest_checkpoint = checkpoints.pop(0) + os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint)) + print(f"Removed old checkpoint: {oldest_checkpoint}") def load_checkpoint(self): if ( diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index cd6fcce..0c3bdc1 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -69,6 +69,12 @@ def parse_args(): action="store_true", help="Use 8-bit Adam optimizer from bitsandbytes", ) + parser.add_argument( + "--keep_last_n_checkpoints", + type=int, + default=-1, + help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints", + ) return parser.parse_args() @@ -158,6 +164,7 @@ def main(): log_samples=args.log_samples, last_per_updates=args.last_per_updates, bnb_optimizer=args.bnb_optimizer, + keep_last_n_checkpoints=args.keep_last_n_checkpoints, ) train_dataset = 
load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 90602a3..79a5c67 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -70,6 +70,7 @@ def save_settings( mixed_precision, logger, ch_8bit_adam, + keep_last_n_checkpoints, ): path_project = os.path.join(path_project_ckpts, project_name) os.makedirs(path_project, exist_ok=True) @@ -94,6 +95,7 @@ def save_settings( "mixed_precision": mixed_precision, "logger": logger, "bnb_optimizer": ch_8bit_adam, + "keep_last_n_checkpoints": keep_last_n_checkpoints, } with open(file_setting, "w") as f: json.dump(settings, f, indent=4) @@ -126,6 +128,7 @@ def load_settings(project_name): "mixed_precision": "none", "logger": "wandb", "bnb_optimizer": False, + "keep_last_n_checkpoints": -1, # Default to keep all checkpoints } return ( settings["exp_name"], @@ -146,6 +149,7 @@ def load_settings(project_name): settings["mixed_precision"], settings["logger"], settings["bnb_optimizer"], + settings["keep_last_n_checkpoints"], ) with open(file_setting, "r") as f: @@ -154,6 +158,8 @@ def load_settings(project_name): settings["logger"] = "wandb" if "bnb_optimizer" not in settings: settings["bnb_optimizer"] = False + if "keep_last_n_checkpoints" not in settings: + settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"] return ( @@ -175,6 +181,7 @@ def load_settings(project_name): settings["mixed_precision"], settings["logger"], settings["bnb_optimizer"], + settings["keep_last_n_checkpoints"], ) @@ -390,6 +397,7 @@ def start_training( stream=False, logger="wandb", ch_8bit_adam=False, + keep_last_n_checkpoints=-1, ): global training_process, tts_api, stop_signal @@ -451,7 +459,8 @@ 
def start_training( f"--num_warmup_updates {num_warmup_updates} " f"--save_per_updates {save_per_updates} " f"--last_per_updates {last_per_updates} " - f"--dataset_name {dataset_name}" + f"--dataset_name {dataset_name} " + f"--keep_last_n_checkpoints {keep_last_n_checkpoints}" ) if finetune: @@ -492,6 +501,7 @@ def start_training( mixed_precision, logger, ch_8bit_adam, + keep_last_n_checkpoints, ) try: @@ -1564,6 +1574,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle with gr.Row(): save_per_updates = gr.Number(label="Save per Updates", value=300) last_per_updates = gr.Number(label="Last per Updates", value=100) + keep_last_n_checkpoints = gr.Number( + label="Keep Last N Checkpoints", + value=-1, + step=1, + precision=0, + info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints", + ) with gr.Row(): ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer") @@ -1592,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle mixed_precisionv, cd_loggerv, ch_8bit_adamv, + keep_last_n_checkpointsv, ) = load_settings(projects_selelect) exp_name.value = exp_namev learning_rate.value = learning_ratev @@ -1611,6 +1629,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle mixed_precision.value = mixed_precisionv cd_logger.value = cd_loggerv ch_8bit_adam.value = ch_8bit_adamv + keep_last_n_checkpoints.value = keep_last_n_checkpointsv ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True) txt_info_train = gr.Text(label="Info", value="") @@ -1670,6 +1689,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle ch_stream, cd_logger, ch_8bit_adam, + keep_last_n_checkpoints, ], outputs=[txt_info_train, start_button, stop_button], ) diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index 69762e4..4157166 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ 
-61,6 +61,7 @@ def main(cfg): mel_spec_type=mel_spec_type, is_local_vocoder=cfg.model.vocoder.is_local, local_vocoder_path=cfg.model.vocoder.local_path, + keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1), ) train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)