0.4.2 fix trainer with grad_accum

SWivid
2025-01-15 18:28:41 +08:00
parent 12d6970271
commit 9e51878d18
2 changed files with 9 additions and 9 deletions

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.4.1"
+version = "0.4.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}

src/f5_tts/model/trainer.py

@@ -56,14 +56,8 @@ class Trainer:
         if logger == "wandb" and not wandb.api.api_key:
             logger = None
-        print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
-        if grad_accumulation_steps > 1 and self.is_main:
-            print(
-                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
-            )
-
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],
@@ -106,6 +100,12 @@ class Trainer:
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
             self.ema_model.to(self.accelerator.device)
 
+            print(f"Using logger: {logger}")
+            if grad_accumulation_steps > 1:
+                print(
+                    "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+                )
+
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
@@ -357,7 +357,7 @@ class Trainer:
                         self.writer.add_scalar("loss", loss.item(), global_update)
                         self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                if global_update % self.save_per_updates == 0:
+                if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update)
 
                     if self.log_samples and self.accelerator.is_local_main_process:
@@ -391,7 +391,7 @@ class Trainer:
                             f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
                         )
 
-                if global_update % self.last_per_updates == 0:
+                if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
                     self.save_checkpoint(global_update, last=True)
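
Why the sync_gradients guard matters: with gradient_accumulation_steps > 1, Accelerate only applies the optimizer update on the micro-batch where gradients synchronize, and global_update holds its value across the other micro-batches in the window. Without the guard, the modulo test can match several micro-batches in a row for the same update, re-saving a checkpoint or saving a half-accumulated state. Below is a minimal, self-contained sketch of this pattern, not the repository's actual Trainer; the model, data, and save cadence are hypothetical stand-ins.

# Sketch of checkpoint gating under gradient accumulation with HF Accelerate.
# Not F5-TTS code: model, data, and cadence below are placeholders.
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)  # hypothetical value

model = torch.nn.Linear(8, 1)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

global_update = 0
save_per_updates = 2  # hypothetical cadence, counted in optimizer updates

for x, y in dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # scales the loss for accumulation
        optimizer.step()            # no-op until the accumulation window closes
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        # The optimizer actually stepped on this micro-batch: one update done.
        global_update += 1

    # Same guard as the diff above: without "and accelerator.sync_gradients",
    # this test would fire on every micro-batch of an accumulation window
    # while global_update stays constant, saving the same update repeatedly.
    if global_update % save_per_updates == 0 and accelerator.sync_gradients:
        accelerator.save_state(f"ckpt_update_{global_update}")  # illustrative path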