Merge pull request #166 from cocktailpeanut/wandb_usability

User-friendly wandb support
This commit is contained in:
Zhikang Niu
2024-10-20 10:25:17 +08:00
committed by GitHub
2 changed files with 45 additions and 20 deletions

View File

@@ -72,6 +72,27 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
## Wandb Logging
By default, the training script does NOT use wandb logging (assuming you didn't manually log in using `wandb login`).
To turn on wandb logging, you can either:
1. Manually log in with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Log in programmatically by setting an environment variable: Get an API key at https://wandb.ai/site/ and set the environment variable as follows:
On Mac & Linux:
```
export WANDB_API_KEY=<YOUR WANDB API KEY>
```
On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
## Inference
The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.

View File

@@ -50,31 +50,35 @@ class Trainer:
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
logger = "wandb" if wandb.api.api_key else None
print(f"Using logger: {logger}")
self.accelerator = Accelerator(
log_with = "wandb",
log_with = logger,
kwargs_handlers = [ddp_kwargs],
gradient_accumulation_steps = grad_accumulation_steps,
**accelerate_kwargs
)
if exists(wandb_resume_id):
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
else:
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
self.accelerator.init_trackers(
project_name = wandb_project,
init_kwargs=init_kwargs,
config={"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler}
)
if logger == "wandb":
if exists(wandb_resume_id):
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
else:
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
self.accelerator.init_trackers(
project_name = wandb_project,
init_kwargs=init_kwargs,
config={"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler}
)
self.model = model