Merge pull request #1266 from ZhikangNiu/main

Make wandb project/run_name/resume_id configurable via Hydra yaml, backward compatible with defaults
This commit is contained in:
Yushen CHEN
2026-02-16 11:10:17 +08:00
committed by GitHub
6 changed files with 23 additions and 4 deletions
+3
View File
@@ -42,6 +42,9 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+3
View File
@@ -42,6 +42,9 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+3
View File
@@ -47,6 +47,9 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+3
View File
@@ -47,6 +47,9 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+3
View File
@@ -48,6 +48,9 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | null
wandb_project: CFM-TTS # wandb project name
wandb_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name} # wandb run name
wandb_resume_id: null # wandb run id for resuming, null to auto-detect from checkpoint
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
+8 -4
View File
@@ -21,8 +21,12 @@ def main(model_cfg):
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None
wandb_project = model_cfg.ckpts.get("wandb_project", "CFM-TTS")
wandb_run_name = model_cfg.ckpts.get(
"wandb_run_name",
f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}",
)
wandb_resume_id = model_cfg.ckpts.get("wandb_resume_id", None)
# set text tokenizer
if tokenizer != "custom":
@@ -53,8 +57,8 @@ def main(model_cfg):
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
max_grad_norm=model_cfg.optim.max_grad_norm,
logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_project=wandb_project,
wandb_run_name=wandb_run_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=model_cfg.ckpts.last_per_updates,
log_samples=model_cfg.ckpts.log_samples,