From ae6e97b83689b991de616ebf3aa3ce063f621b5c Mon Sep 17 00:00:00 2001
From: cocktailpeanut
Date: Fri, 18 Oct 2024 14:59:59 -0400
Subject: [PATCH] user-friendly wandb support

---
 README.md        | 21 +++++++++++++++++++++
 model/trainer.py | 44 ++++++++++++++++++++++++--------------------
 2 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 5efbd88..48f10fb 100644
--- a/README.md
+++ b/README.md
@@ -67,6 +67,27 @@ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discuss
 
 Gradio UI finetuning with `finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
 
+## Wandb Logging
+
+By default, the training script does NOT log to wandb (assuming you haven't already logged in with `wandb login`).
+
+To turn on wandb logging, you can either:
+
+1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
+2. Log in programmatically by setting an environment variable: get an API key at https://wandb.ai/site/ and set the environment variable as follows:
+
+On Mac & Linux:
+
+```
+export WANDB_API_KEY=<YOUR_API_KEY>
+```
+
+On Windows:
+
+```
+set WANDB_API_KEY=<YOUR_API_KEY>
+```
+
 ## Inference
 
 The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or automatically downloaded with `inference-cli` and `gradio_app`.
diff --git a/model/trainer.py b/model/trainer.py
index 676a7d0..c5c956a 100644
--- a/model/trainer.py
+++ b/model/trainer.py
@@ -50,31 +50,35 @@ class Trainer:
 
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
 
+        logger = "wandb" if wandb.api.api_key else None
+        print(f"Using logger: {logger}")
+
         self.accelerator = Accelerator(
-            log_with = "wandb",
+            log_with = logger,
             kwargs_handlers = [ddp_kwargs],
             gradient_accumulation_steps = grad_accumulation_steps,
             **accelerate_kwargs
         )
-
-        if exists(wandb_resume_id):
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
-        else:
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
-        self.accelerator.init_trackers(
-            project_name = wandb_project,
-            init_kwargs=init_kwargs,
-            config={"epochs": epochs,
-                    "learning_rate": learning_rate,
-                    "num_warmup_updates": num_warmup_updates,
-                    "batch_size": batch_size,
-                    "batch_size_type": batch_size_type,
-                    "max_samples": max_samples,
-                    "grad_accumulation_steps": grad_accumulation_steps,
-                    "max_grad_norm": max_grad_norm,
-                    "gpus": self.accelerator.num_processes,
-                    "noise_scheduler": noise_scheduler}
-            )
+
+        if logger == "wandb":
+            if exists(wandb_resume_id):
+                init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
+            else:
+                init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
+            self.accelerator.init_trackers(
+                project_name = wandb_project,
+                init_kwargs=init_kwargs,
+                config={"epochs": epochs,
+                        "learning_rate": learning_rate,
+                        "num_warmup_updates": num_warmup_updates,
+                        "batch_size": batch_size,
+                        "batch_size_type": batch_size_type,
+                        "max_samples": max_samples,
+                        "grad_accumulation_steps": grad_accumulation_steps,
+                        "max_grad_norm": max_grad_norm,
+                        "gpus": self.accelerator.num_processes,
+                        "noise_scheduler": noise_scheduler}
+                )
 
         self.model = model
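
For reference, the opt-in logger selection that the `model/trainer.py` hunk introduces reduces to the standalone sketch below. It mirrors the patched logic rather than reproducing the `Trainer` class verbatim; the project name, run name, and config dict are placeholder values, not the real training arguments.

```python
# Minimal sketch of the opt-in wandb behavior introduced by this patch.
# Placeholders: "my_project", "my_run", and the config dict are illustrative only.
import wandb
from accelerate import Accelerator

# wandb.api.api_key is falsy when no credentials are configured
# (neither `wandb login` nor WANDB_API_KEY), so logging stays off by default.
logger = "wandb" if wandb.api.api_key else None
print(f"Using logger: {logger}")

accelerator = Accelerator(log_with=logger)

# Trackers are initialized only when a wandb credential was found.
if logger == "wandb":
    accelerator.init_trackers(
        project_name="my_project",
        init_kwargs={"wandb": {"resume": "allow", "name": "my_run"}},
        config={"learning_rate": 1e-4, "batch_size": 32},
    )
```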