mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-06-21 22:11:39 -07:00
Merge pull request #1266 from ZhikangNiu/main
Make wandb project/run_name/resume_id configurable via Hydra yaml, backward compatible with defaults
This commit is contained in:
@@ -42,6 +42,9 @@ model:
|
||||
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | null
|
||||
wandb_project: CFM-TTS # wandb project name
|
||||
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
|
||||
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
|
||||
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
|
||||
@@ -42,6 +42,9 @@ model:
|
||||
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | null
|
||||
wandb_project: CFM-TTS # wandb project name
|
||||
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
|
||||
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
|
||||
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
|
||||
@@ -47,6 +47,9 @@ model:
|
||||
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | null
|
||||
wandb_project: CFM-TTS # wandb project name
|
||||
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
|
||||
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
|
||||
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
|
||||
@@ -47,6 +47,9 @@ model:
|
||||
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | null
|
||||
wandb_project: CFM-TTS # wandb project name
|
||||
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
|
||||
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
|
||||
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
|
||||
@@ -48,6 +48,9 @@ model:
|
||||
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | null
|
||||
wandb_project: CFM-TTS # wandb project name
|
||||
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
|
||||
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
|
||||
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
|
||||
@@ -21,8 +21,12 @@ def main(model_cfg):
|
||||
tokenizer = model_cfg.model.tokenizer
|
||||
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
|
||||
|
||||
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
|
||||
wandb_resume_id = None
|
||||
wandb_project = model_cfg.ckpts.get("wandb_project", "CFM-TTS")
|
||||
wandb_run_name = model_cfg.ckpts.get(
|
||||
"wandb_run_name",
|
||||
f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}",
|
||||
)
|
||||
wandb_resume_id = model_cfg.ckpts.get("wandb_resume_id", None)
|
||||
|
||||
# set text tokenizer
|
||||
if tokenizer != "custom":
|
||||
@@ -53,8 +57,8 @@ def main(model_cfg):
|
||||
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
|
||||
max_grad_norm=model_cfg.optim.max_grad_norm,
|
||||
logger=model_cfg.ckpts.logger,
|
||||
wandb_project="CFM-TTS",
|
||||
wandb_run_name=exp_name,
|
||||
wandb_project=wandb_project,
|
||||
wandb_run_name=wandb_run_name,
|
||||
wandb_resume_id=wandb_resume_id,
|
||||
last_per_updates=model_cfg.ckpts.last_per_updates,
|
||||
log_samples=model_cfg.ckpts.log_samples,
|
||||
|
||||
Reference in New Issue
Block a user