mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 09:39:52 -08:00
0.4.2 fix trainer with grad_accum
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
|
||||
@@ -56,14 +56,8 @@ class Trainer:
|
||||
|
||||
if logger == "wandb" and not wandb.api.api_key:
|
||||
logger = None
|
||||
print(f"Using logger: {logger}")
|
||||
self.log_samples = log_samples
|
||||
|
||||
if grad_accumulation_steps > 1 and self.is_main:
|
||||
print(
|
||||
"Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
|
||||
)
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with=logger if logger == "wandb" else None,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
@@ -106,6 +100,12 @@ class Trainer:
|
||||
self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
|
||||
self.ema_model.to(self.accelerator.device)
|
||||
|
||||
print(f"Using logger: {logger}")
|
||||
if grad_accumulation_steps > 1:
|
||||
print(
|
||||
"Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
|
||||
)
|
||||
|
||||
self.epochs = epochs
|
||||
self.num_warmup_updates = num_warmup_updates
|
||||
self.save_per_updates = save_per_updates
|
||||
@@ -357,7 +357,7 @@ class Trainer:
|
||||
self.writer.add_scalar("loss", loss.item(), global_update)
|
||||
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
|
||||
|
||||
if global_update % self.save_per_updates == 0:
|
||||
if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
|
||||
self.save_checkpoint(global_update)
|
||||
|
||||
if self.log_samples and self.accelerator.is_local_main_process:
|
||||
@@ -391,7 +391,7 @@ class Trainer:
|
||||
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
|
||||
)
|
||||
|
||||
if global_update % self.last_per_updates == 0:
|
||||
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
|
||||
self.save_checkpoint(global_update, last=True)
|
||||
|
||||
self.save_checkpoint(global_update, last=True)
|
||||
|
||||
Reference in New Issue
Block a user