diff --git a/pyproject.toml b/pyproject.toml
index 13c8b78..118e0eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.4.1"
+version = "0.4.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index e57a389..d397bab 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -56,14 +56,8 @@ class Trainer:
 
         if logger == "wandb" and not wandb.api.api_key:
             logger = None
-        print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
-        if grad_accumulation_steps > 1 and self.is_main:
-            print(
-                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
-            )
-
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
@@ -106,6 +100,12 @@ class Trainer:
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
             self.ema_model.to(self.accelerator.device)
 
+        print(f"Using logger: {logger}")
+        if grad_accumulation_steps > 1:
+            print(
+                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+            )
+
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
@@ -357,7 +357,7 @@ class Trainer:
                         self.writer.add_scalar("loss", loss.item(), global_update)
                         self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                if global_update % self.save_per_updates == 0:
+                if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update)
 
                     if self.log_samples and self.accelerator.is_local_main_process:
@@ -391,7 +391,7 @@ class Trainer:
                             f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
                         )
 
-                if global_update % self.last_per_updates == 0:
+                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update, last=True)
 
         self.save_checkpoint(global_update, last=True)