From 12d6970271f5cdb91938f8ee7b2bbc60e60a0ea8 Mon Sep 17 00:00:00 2001 From: SWivid Date: Wed, 15 Jan 2025 15:06:55 +0800 Subject: [PATCH] 0.4.1 #718 add keep_last_n_checkpoints option --- pyproject.toml | 2 +- src/f5_tts/configs/E2TTS_Base_train.yaml | 2 +- src/f5_tts/configs/E2TTS_Small_train.yaml | 2 +- src/f5_tts/configs/F5TTS_Base_train.yaml | 2 +- src/f5_tts/configs/F5TTS_Small_train.yaml | 2 +- src/f5_tts/model/trainer.py | 20 +-- src/f5_tts/train/finetune_cli.py | 14 +- src/f5_tts/train/finetune_gradio.py | 165 +++++++--------------- src/f5_tts/train/train.py | 2 +- 9 files changed, 68 insertions(+), 143 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7912069..13c8b78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts" -version = "0.4.0" +version = "0.4.1" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" readme = "README.md" license = {text = "MIT License"} diff --git a/src/f5_tts/configs/E2TTS_Base_train.yaml b/src/f5_tts/configs/E2TTS_Base_train.yaml index 5874a7c..da23b05 100644 --- a/src/f5_tts/configs/E2TTS_Base_train.yaml +++ b/src/f5_tts/configs/E2TTS_Base_train.yaml @@ -40,6 +40,6 @@ model: ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates + keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates - keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/configs/E2TTS_Small_train.yaml b/src/f5_tts/configs/E2TTS_Small_train.yaml index a14bf41..b2d1a6c 100644 --- a/src/f5_tts/configs/E2TTS_Small_train.yaml +++ b/src/f5_tts/configs/E2TTS_Small_train.yaml @@ -40,6 +40,6 @@ model: ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates + keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates - keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/configs/F5TTS_Base_train.yaml b/src/f5_tts/configs/F5TTS_Base_train.yaml index f4a6a00..ff8639f 100644 --- a/src/f5_tts/configs/F5TTS_Base_train.yaml +++ b/src/f5_tts/configs/F5TTS_Base_train.yaml @@ -43,6 +43,6 @@ model: ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates + keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates - keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/configs/F5TTS_Small_train.yaml b/src/f5_tts/configs/F5TTS_Small_train.yaml index e2ad2cc..790be06 100644 --- a/src/f5_tts/configs/F5TTS_Small_train.yaml +++ b/src/f5_tts/configs/F5TTS_Small_train.yaml @@ -43,6 +43,6 @@ model: ckpts: logger: wandb # wandb | tensorboard | None save_per_updates: 50000 # save checkpoint per updates + keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints last_per_updates: 5000 # save last checkpoint per updates - keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} \ No newline at end of file diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index f96fe9c..e57a389 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -30,6 +30,7 @@ class Trainer: learning_rate, num_warmup_updates=20000, save_per_updates=1000, + keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints checkpoint_path=None, batch_size=32, batch_size_type: str = "sample", @@ -50,17 +51,7 @@ class Trainer: mel_spec_type: str = "vocos", # "vocos" | "bigvgan" is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path - keep_last_n_checkpoints: int - | None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints ): - # Validate keep_last_n_checkpoints - if not isinstance(keep_last_n_checkpoints, int): - raise ValueError("keep_last_n_checkpoints must be an integer") - if keep_last_n_checkpoints < -1: - raise ValueError( - "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer" - ) - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) if logger == "wandb" and not wandb.api.api_key: @@ -118,6 +109,7 @@ class Trainer: self.epochs = epochs self.num_warmup_updates = num_warmup_updates self.save_per_updates = save_per_updates + self.keep_last_n_checkpoints = keep_last_n_checkpoints self.last_per_updates = default(last_per_updates, save_per_updates) self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts") @@ -144,8 +136,6 @@ class Trainer: self.optimizer = AdamW(model.parameters(), lr=learning_rate) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) - self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None - @property def is_main(self): return self.accelerator.is_main_process @@ -166,22 +156,16 @@ class Trainer: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") print(f"Saved last checkpoint at update {update}") else: - # Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0 if self.keep_last_n_checkpoints == 0: return - self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt") - # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive if self.keep_last_n_checkpoints > 0: - # Get all checkpoint files except model_last.pt checkpoints = [ f for f in os.listdir(self.checkpoint_path) if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt" ] - # Sort by step number checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0])) - # Remove old checkpoints if we have more than keep_last_n_checkpoints while len(checkpoints) > self.keep_last_n_checkpoints: oldest_checkpoint = checkpoints.pop(0) os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint)) diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index 0c3bdc1..a148d7b 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -46,6 +46,12 @@ def parse_args(): parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs") parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates") parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates") + parser.add_argument( + "--keep_last_n_checkpoints", + type=int, + default=-1, + help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints", + ) parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates") parser.add_argument("--finetune", action="store_true", help="Use Finetune") parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint") @@ -69,12 +75,6 @@ def parse_args(): action="store_true", help="Use 8-bit Adam optimizer from bitsandbytes", ) - parser.add_argument( - "--keep_last_n_checkpoints", - type=int, - default=-1, - help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints", - ) return parser.parse_args() @@ -151,6 +151,7 @@ def main(): args.learning_rate, num_warmup_updates=args.num_warmup_updates, save_per_updates=args.save_per_updates, + keep_last_n_checkpoints=args.keep_last_n_checkpoints, checkpoint_path=checkpoint_path, batch_size=args.batch_size_per_gpu, batch_size_type=args.batch_size_type, @@ -164,7 +165,6 @@ def main(): log_samples=args.log_samples, last_per_updates=args.last_per_updates, bnb_optimizer=args.bnb_optimizer, - keep_last_n_checkpoints=args.keep_last_n_checkpoints, ) train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs) diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 79a5c67..e27ef3a 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -62,6 +62,7 @@ def save_settings( epochs, num_warmup_updates, save_per_updates, + keep_last_n_checkpoints, last_per_updates, finetune, file_checkpoint_train, @@ -70,7 +71,6 @@ def save_settings( mixed_precision, logger, ch_8bit_adam, - keep_last_n_checkpoints, ): path_project = os.path.join(path_project_ckpts, project_name) os.makedirs(path_project, exist_ok=True) @@ -87,6 +87,7 @@ def save_settings( "epochs": epochs, "num_warmup_updates": num_warmup_updates, "save_per_updates": save_per_updates, + "keep_last_n_checkpoints": keep_last_n_checkpoints, "last_per_updates": last_per_updates, "finetune": finetune, "file_checkpoint_train": file_checkpoint_train, @@ -95,7 +96,6 @@ def save_settings( "mixed_precision": mixed_precision, "logger": logger, "bnb_optimizer": ch_8bit_adam, - "keep_last_n_checkpoints": keep_last_n_checkpoints, } with open(file_setting, "w") as f: json.dump(settings, f, indent=4) @@ -120,6 +120,7 @@ def load_settings(project_name): "epochs": 100, "num_warmup_updates": 2, "save_per_updates": 300, + "keep_last_n_checkpoints": -1, "last_per_updates": 100, "finetune": True, "file_checkpoint_train": "", @@ -128,61 +129,20 @@ def load_settings(project_name): "mixed_precision": "none", "logger": "wandb", "bnb_optimizer": False, - "keep_last_n_checkpoints": -1, # Default to keep all checkpoints } - return ( - settings["exp_name"], - settings["learning_rate"], - settings["batch_size_per_gpu"], - settings["batch_size_type"], - settings["max_samples"], - settings["grad_accumulation_steps"], - settings["max_grad_norm"], - settings["epochs"], - settings["num_warmup_updates"], - settings["save_per_updates"], - settings["last_per_updates"], - settings["finetune"], - settings["file_checkpoint_train"], - settings["tokenizer_type"], - settings["tokenizer_file"], - settings["mixed_precision"], - settings["logger"], - settings["bnb_optimizer"], - settings["keep_last_n_checkpoints"], - ) + else: + with open(file_setting, "r") as f: + settings = json.load(f) + if "logger" not in settings: + settings["logger"] = "wandb" + if "bnb_optimizer" not in settings: + settings["bnb_optimizer"] = False + if "keep_last_n_checkpoints" not in settings: + settings["keep_last_n_checkpoints"] = -1 # default to keep all checkpoints + if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e + settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"] - with open(file_setting, "r") as f: - settings = json.load(f) - if "logger" not in settings: - settings["logger"] = "wandb" - if "bnb_optimizer" not in settings: - settings["bnb_optimizer"] = False - if "keep_last_n_checkpoints" not in settings: - settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints - if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e - settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"] - return ( - settings["exp_name"], - settings["learning_rate"], - settings["batch_size_per_gpu"], - settings["batch_size_type"], - settings["max_samples"], - settings["grad_accumulation_steps"], - settings["max_grad_norm"], - settings["epochs"], - settings["num_warmup_updates"], - settings["save_per_updates"], - settings["last_per_updates"], - settings["finetune"], - settings["file_checkpoint_train"], - settings["tokenizer_type"], - settings["tokenizer_file"], - settings["mixed_precision"], - settings["logger"], - settings["bnb_optimizer"], - settings["keep_last_n_checkpoints"], - ) + return settings # Load metadata @@ -388,6 +348,7 @@ def start_training( epochs=11, num_warmup_updates=200, save_per_updates=400, + keep_last_n_checkpoints=-1, last_per_updates=800, finetune=True, file_checkpoint_train="", @@ -397,7 +358,6 @@ def start_training( stream=False, logger="wandb", ch_8bit_adam=False, - keep_last_n_checkpoints=-1, ): global training_process, tts_api, stop_signal @@ -448,19 +408,19 @@ def start_training( fp16 = "" cmd = ( - f"accelerate launch {fp16} {file_train} --exp_name {exp_name} " - f"--learning_rate {learning_rate} " - f"--batch_size_per_gpu {batch_size_per_gpu} " - f"--batch_size_type {batch_size_type} " - f"--max_samples {max_samples} " - f"--grad_accumulation_steps {grad_accumulation_steps} " - f"--max_grad_norm {max_grad_norm} " - f"--epochs {epochs} " - f"--num_warmup_updates {num_warmup_updates} " - f"--save_per_updates {save_per_updates} " - f"--last_per_updates {last_per_updates} " - f"--dataset_name {dataset_name} " - f"--keep_last_n_checkpoints {keep_last_n_checkpoints}" + f"accelerate launch {fp16} {file_train} --exp_name {exp_name}" + f" --learning_rate {learning_rate}" + f" --batch_size_per_gpu {batch_size_per_gpu}" + f" --batch_size_type {batch_size_type}" + f" --max_samples {max_samples}" + f" --grad_accumulation_steps {grad_accumulation_steps}" + f" --max_grad_norm {max_grad_norm}" + f" --epochs {epochs}" + f" --num_warmup_updates {num_warmup_updates}" + f" --save_per_updates {save_per_updates}" + f" --keep_last_n_checkpoints {keep_last_n_checkpoints}" + f" --last_per_updates {last_per_updates}" + f" --dataset_name {dataset_name}" ) if finetune: @@ -493,6 +453,7 @@ def start_training( epochs, num_warmup_updates, save_per_updates, + keep_last_n_checkpoints, last_per_updates, finetune, file_checkpoint_train, @@ -501,7 +462,6 @@ def start_training( mixed_precision, logger, ch_8bit_adam, - keep_last_n_checkpoints, ) try: @@ -1573,7 +1533,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle with gr.Row(): save_per_updates = gr.Number(label="Save per Updates", value=300) - last_per_updates = gr.Number(label="Last per Updates", value=100) keep_last_n_checkpoints = gr.Number( label="Keep Last N Checkpoints", value=-1, @@ -1581,6 +1540,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle precision=0, info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints", ) + last_per_updates = gr.Number(label="Last per Updates", value=100) with gr.Row(): ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer") @@ -1590,46 +1550,27 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle stop_button = gr.Button("Stop Training", interactive=False) if projects_selelect is not None: - ( - exp_namev, - learning_ratev, - batch_size_per_gpuv, - batch_size_typev, - max_samplesv, - grad_accumulation_stepsv, - max_grad_normv, - epochsv, - num_warmupv_updatesv, - save_per_updatesv, - last_per_updatesv, - finetunev, - file_checkpoint_trainv, - tokenizer_typev, - tokenizer_filev, - mixed_precisionv, - cd_loggerv, - ch_8bit_adamv, - keep_last_n_checkpointsv, - ) = load_settings(projects_selelect) - exp_name.value = exp_namev - learning_rate.value = learning_ratev - batch_size_per_gpu.value = batch_size_per_gpuv - batch_size_type.value = batch_size_typev - max_samples.value = max_samplesv - grad_accumulation_steps.value = grad_accumulation_stepsv - max_grad_norm.value = max_grad_normv - epochs.value = epochsv - num_warmup_updates.value = num_warmupv_updatesv - save_per_updates.value = save_per_updatesv - last_per_updates.value = last_per_updatesv - ch_finetune.value = finetunev - file_checkpoint_train.value = file_checkpoint_trainv - tokenizer_type.value = tokenizer_typev - tokenizer_file.value = tokenizer_filev - mixed_precision.value = mixed_precisionv - cd_logger.value = cd_loggerv - ch_8bit_adam.value = ch_8bit_adamv - keep_last_n_checkpoints.value = keep_last_n_checkpointsv + settings = load_settings(projects_selelect) + + exp_name.value = settings["exp_name"] + learning_rate.value = settings["learning_rate"] + batch_size_per_gpu.value = settings["batch_size_per_gpu"] + batch_size_type.value = settings["batch_size_type"] + max_samples.value = settings["max_samples"] + grad_accumulation_steps.value = settings["grad_accumulation_steps"] + max_grad_norm.value = settings["max_grad_norm"] + epochs.value = settings["epochs"] + num_warmup_updates.value = settings["num_warmup_updates"] + save_per_updates.value = settings["save_per_updates"] + keep_last_n_checkpoints.value = settings["keep_last_n_checkpoints"] + last_per_updates.value = settings["last_per_updates"] + ch_finetune.value = settings["finetune"] + file_checkpoint_train.value = settings["file_checkpoint_train"] + tokenizer_type.value = settings["tokenizer_type"] + tokenizer_file.value = settings["tokenizer_file"] + mixed_precision.value = settings["mixed_precision"] + cd_logger.value = settings["logger"] + ch_8bit_adam.value = settings["bnb_optimizer"] ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True) txt_info_train = gr.Text(label="Info", value="") @@ -1680,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle epochs, num_warmup_updates, save_per_updates, + keep_last_n_checkpoints, last_per_updates, ch_finetune, file_checkpoint_train, @@ -1689,7 +1631,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle ch_stream, cd_logger, ch_8bit_adam, - keep_last_n_checkpoints, ], outputs=[txt_info_train, start_button, stop_button], ) diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index 4157166..ade54be 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -45,6 +45,7 @@ def main(cfg): learning_rate=cfg.optim.learning_rate, num_warmup_updates=cfg.optim.num_warmup_updates, save_per_updates=cfg.ckpts.save_per_updates, + keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1), checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")), batch_size=cfg.datasets.batch_size_per_gpu, batch_size_type=cfg.datasets.batch_size_type, @@ -61,7 +62,6 @@ def main(cfg): mel_spec_type=mel_spec_type, is_local_vocoder=cfg.model.vocoder.is_local, local_vocoder_path=cfg.model.vocoder.local_path, - keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None), ) train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)