diff --git a/src/f5_tts/configs/E2TTS_Base.yaml b/src/f5_tts/configs/E2TTS_Base.yaml index ee70182..45e6cbc 100644 --- a/src/f5_tts/configs/E2TTS_Base.yaml +++ b/src/f5_tts/configs/E2TTS_Base.yaml @@ -42,6 +42,9 @@ model: ckpts: logger: wandb # wandb | tensorboard | null + wandb_project: CFM-TTS # wandb project name + wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name + wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints diff --git a/src/f5_tts/configs/E2TTS_Small.yaml b/src/f5_tts/configs/E2TTS_Small.yaml index cbb1f44..001a4ca 100644 --- a/src/f5_tts/configs/E2TTS_Small.yaml +++ b/src/f5_tts/configs/E2TTS_Small.yaml @@ -42,6 +42,9 @@ model: ckpts: logger: wandb # wandb | tensorboard | null + wandb_project: CFM-TTS # wandb project name + wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name + wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints diff --git a/src/f5_tts/configs/F5TTS_Base.yaml b/src/f5_tts/configs/F5TTS_Base.yaml index d177674..89e0ffa 100644 --- a/src/f5_tts/configs/F5TTS_Base.yaml +++ b/src/f5_tts/configs/F5TTS_Base.yaml @@ -47,6 +47,9 @@ model: ckpts: logger: wandb # wandb | tensorboard | null + wandb_project: CFM-TTS # wandb project name + wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name + wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints diff --git a/src/f5_tts/configs/F5TTS_Small.yaml b/src/f5_tts/configs/F5TTS_Small.yaml index 396f389..ed9e86e 100644 --- a/src/f5_tts/configs/F5TTS_Small.yaml +++ b/src/f5_tts/configs/F5TTS_Small.yaml @@ -47,6 +47,9 @@ model: ckpts: logger: wandb # wandb | tensorboard | null + wandb_project: CFM-TTS # wandb project name + wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name + wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints diff --git a/src/f5_tts/configs/F5TTS_v1_Base.yaml b/src/f5_tts/configs/F5TTS_v1_Base.yaml index e931a01..a468c89 100644 --- a/src/f5_tts/configs/F5TTS_v1_Base.yaml +++ b/src/f5_tts/configs/F5TTS_v1_Base.yaml @@ -48,6 +48,9 @@ model: ckpts: logger: wandb # wandb | tensorboard | null + wandb_project: CFM-TTS # wandb project name + wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name + wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples save_per_updates: 50000 # save checkpoint per updates keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py index b948ab1..84196a4 100644 --- a/src/f5_tts/train/train.py +++ b/src/f5_tts/train/train.py @@ -21,8 +21,12 @@ def main(model_cfg): tokenizer = model_cfg.model.tokenizer mel_spec_type = model_cfg.model.mel_spec.mel_spec_type - exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}" - wandb_resume_id = None + wandb_project = model_cfg.ckpts.get("wandb_project", "CFM-TTS") + wandb_run_name = model_cfg.ckpts.get( + "wandb_run_name", + f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}", + ) + wandb_resume_id = model_cfg.ckpts.get("wandb_resume_id", None) # set text tokenizer if tokenizer != "custom": @@ -53,8 +57,8 @@ def main(model_cfg): grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps, max_grad_norm=model_cfg.optim.max_grad_norm, logger=model_cfg.ckpts.logger, - wandb_project="CFM-TTS", - wandb_run_name=exp_name, + wandb_project=wandb_project, + wandb_run_name=wandb_run_name, wandb_resume_id=wandb_resume_id, last_per_updates=model_cfg.ckpts.last_per_updates, log_samples=model_cfg.ckpts.log_samples,