diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py index b1bab71..7a405c2 100644 --- a/src/f5_tts/model/trainer.py +++ b/src/f5_tts/model/trainer.py @@ -51,7 +51,7 @@ class Trainer: mel_spec_type: str = "vocos", # "vocos" | "bigvgan" is_local_vocoder: bool = False, # use local path vocoder local_vocoder_path: str = "", # local vocoder path - cfg_dict: dict = dict(), # training config + model_cfg_dict: dict = dict(), # training config ): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) @@ -73,8 +73,8 @@ class Trainer: else: init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}} - if not cfg_dict: - cfg_dict = { + if not model_cfg_dict: + model_cfg_dict = { "epochs": epochs, "learning_rate": learning_rate, "num_warmup_updates": num_warmup_updates, @@ -85,11 +85,11 @@ class Trainer: "max_grad_norm": max_grad_norm, "noise_scheduler": noise_scheduler, } - cfg_dict["gpus"] = self.accelerator.num_processes + model_cfg_dict["gpus"] = self.accelerator.num_processes self.accelerator.init_trackers( project_name=wandb_project, init_kwargs=init_kwargs, - config=cfg_dict, + config=model_cfg_dict, ) elif self.logger == "tensorboard":