0.4.0 fix gradient accumulation; change checkpointing logic to per_updates

This commit is contained in:
unknown
2025-01-12 21:26:57 +08:00
parent f992c4e844
commit 0b11f7eae6
10 changed files with 114 additions and 94 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "0.3.4"
version = "0.4.0"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}

View File

@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -39,6 +39,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0
bnb_optimizer: False
@@ -39,6 +39,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -12,7 +12,7 @@ datasets:
optim:
epochs: 15
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup steps
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per steps
last_per_steps: 5000 # save last checkpoint per steps
save_per_updates: 50000 # save checkpoint per updates
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import gc
import math
import os
import torch
@@ -42,7 +43,7 @@ class Trainer:
wandb_run_name="test_run",
wandb_resume_id: str = None,
log_samples: bool = False,
last_per_steps=None,
last_per_updates=None,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
bnb_optimizer: bool = False,
@@ -57,6 +58,11 @@ class Trainer:
print(f"Using logger: {logger}")
self.log_samples = log_samples
if grad_accumulation_steps > 1 and self.is_main:
print(
"Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
)
self.accelerator = Accelerator(
log_with=logger if logger == "wandb" else None,
kwargs_handlers=[ddp_kwargs],
@@ -102,7 +108,7 @@ class Trainer:
self.epochs = epochs
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
self.batch_size = batch_size
@@ -132,7 +138,7 @@ class Trainer:
def is_main(self):
return self.accelerator.is_main_process
def save_checkpoint(self, step, last=False):
def save_checkpoint(self, update, last=False):
self.accelerator.wait_for_everyone()
if self.is_main:
checkpoint = dict(
@@ -140,15 +146,15 @@ class Trainer:
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
step=step,
update=update,
)
if not os.path.exists(self.checkpoint_path):
os.makedirs(self.checkpoint_path)
if last:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at step {step}")
print(f"Saved last checkpoint at update {update}")
else:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
def load_checkpoint(self):
if (
@@ -177,7 +183,14 @@ class Trainer:
if self.is_main:
self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
if "step" in checkpoint:
if "update" in checkpoint or "step" in checkpoint:
# patch for backward compatibility, with before f992c4e
if "step" in checkpoint:
checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
if self.grad_accumulation_steps > 1 and self.is_main:
print(
"F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
)
# patch for backward compatibility, 305e3ea
for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
if key in checkpoint["model_state_dict"]:
@@ -187,19 +200,19 @@ class Trainer:
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
step = checkpoint["step"]
update = checkpoint["update"]
else:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
if k not in ["initted", "update", "step"]
}
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
step = 0
update = 0
del checkpoint
gc.collect()
return step
return update
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
if self.log_samples:
@@ -248,25 +261,26 @@ class Trainer:
# accelerator.prepare() dispatches batches to devices;
# which means the length of dataloader calculated before, should consider the number of devices
warmup_steps = (
warmup_updates = (
self.num_warmup_updates * self.accelerator.num_processes
) # consider a fixed warmup steps while using accelerate multi-gpu ddp
# otherwise by default with split_batches=False, warmup steps change with num_processes
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
decay_steps = total_steps - warmup_steps
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
decay_updates = total_updates - warmup_updates
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
self.scheduler = SequentialLR(
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
)
train_dataloader, self.scheduler = self.accelerator.prepare(
train_dataloader, self.scheduler
) # actual steps = 1 gpu steps / gpus
start_step = self.load_checkpoint()
global_step = start_step
) # actual multi_gpu updates = single_gpu updates / gpu nums
start_update = self.load_checkpoint()
global_update = start_update
if exists(resumable_with_seed):
orig_epoch_step = len(train_dataloader)
start_step = start_update * self.grad_accumulation_steps
skipped_epoch = int(start_step // orig_epoch_step)
skipped_batch = start_step % orig_epoch_step
skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
@@ -276,23 +290,21 @@ class Trainer:
for epoch in range(skipped_epoch, self.epochs):
self.model.train()
if exists(resumable_with_seed) and epoch == skipped_epoch:
progress_bar = tqdm(
skipped_dataloader,
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="step",
disable=not self.accelerator.is_local_main_process,
initial=skipped_batch,
total=orig_epoch_step,
)
progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
current_dataloader = skipped_dataloader
else:
progress_bar = tqdm(
train_dataloader,
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="step",
disable=not self.accelerator.is_local_main_process,
)
progress_bar_initial = 0
current_dataloader = train_dataloader
for batch in progress_bar:
progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
desc=f"Epoch {epoch+1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
)
for batch in current_dataloader:
with self.accelerator.accumulate(self.model):
text_inputs = batch["text"]
mel_spec = batch["mel"].permute(0, 2, 1)
@@ -301,7 +313,7 @@ class Trainer:
# TODO. add duration predictor training
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
loss, cond, pred = self.model(
mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -318,18 +330,20 @@ class Trainer:
if self.is_main and self.accelerator.sync_gradients:
self.ema_model.update()
global_step += 1
global_update += 1
progress_bar.update(1)
progress_bar.set_postfix(update=str(global_update), loss=loss.item())
if self.accelerator.is_local_main_process:
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
self.accelerator.log(
{"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
)
if self.logger == "tensorboard":
self.writer.add_scalar("loss", loss.item(), global_step)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
self.save_checkpoint(global_step)
if global_update % self.save_per_updates == 0:
self.save_checkpoint(global_update)
if self.log_samples and self.accelerator.is_local_main_process:
ref_audio_len = mel_lengths[0]
@@ -355,12 +369,16 @@ class Trainer:
gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
torchaudio.save(
f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
)
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)
if global_step % self.last_per_steps == 0:
self.save_checkpoint(global_step, last=True)
if global_update % self.last_per_updates == 0:
self.save_checkpoint(global_update, last=True)
self.save_checkpoint(global_step, last=True)
self.save_checkpoint(global_update, last=True)
self.accelerator.end_training()

View File

@@ -20,13 +20,13 @@ grad_accum = 1
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
steps_per_epoch = updates_per_epoch * grad_accum
# steps_per_epoch = updates_per_epoch * grad_accum
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
print(f" or approx. 0/{steps_per_epoch:.0f} steps")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
# others
print(f"total {total_hours:.0f} hours")

View File

@@ -27,7 +27,7 @@ def parse_args():
# num_warmup_updates = 300 for 5000 sample about 10 hours
# change save_per_updates , last_per_steps change this value what you need ,
# change save_per_updates , last_per_updates change this value what you need ,
parser = argparse.ArgumentParser(description="Train CFM Model")
@@ -44,9 +44,9 @@ def parse_args():
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
@@ -61,7 +61,7 @@ def parse_args():
parser.add_argument(
"--log_samples",
action="store_true",
help="Log inferenced samples per ckpt save steps",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
parser.add_argument(
@@ -156,7 +156,7 @@ def main():
wandb_run_name=args.exp_name,
wandb_resume_id=wandb_resume_id,
log_samples=args.log_samples,
last_per_steps=args.last_per_steps,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
)

View File

@@ -62,7 +62,7 @@ def save_settings(
epochs,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
finetune,
file_checkpoint_train,
tokenizer_type,
@@ -86,7 +86,7 @@ def save_settings(
"epochs": epochs,
"num_warmup_updates": num_warmup_updates,
"save_per_updates": save_per_updates,
"last_per_steps": last_per_steps,
"last_per_updates": last_per_updates,
"finetune": finetune,
"file_checkpoint_train": file_checkpoint_train,
"tokenizer_type": tokenizer_type,
@@ -118,7 +118,7 @@ def load_settings(project_name):
"epochs": 100,
"num_warmup_updates": 2,
"save_per_updates": 300,
"last_per_steps": 100,
"last_per_updates": 100,
"finetune": True,
"file_checkpoint_train": "",
"tokenizer_type": "pinyin",
@@ -138,7 +138,7 @@ def load_settings(project_name):
settings["epochs"],
settings["num_warmup_updates"],
settings["save_per_updates"],
settings["last_per_steps"],
settings["last_per_updates"],
settings["finetune"],
settings["file_checkpoint_train"],
settings["tokenizer_type"],
@@ -154,6 +154,8 @@ def load_settings(project_name):
settings["logger"] = "wandb"
if "bnb_optimizer" not in settings:
settings["bnb_optimizer"] = False
if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
return (
settings["exp_name"],
settings["learning_rate"],
@@ -165,7 +167,7 @@ def load_settings(project_name):
settings["epochs"],
settings["num_warmup_updates"],
settings["save_per_updates"],
settings["last_per_steps"],
settings["last_per_updates"],
settings["finetune"],
settings["file_checkpoint_train"],
settings["tokenizer_type"],
@@ -379,7 +381,7 @@ def start_training(
epochs=11,
num_warmup_updates=200,
save_per_updates=400,
last_per_steps=800,
last_per_updates=800,
finetune=True,
file_checkpoint_train="",
tokenizer_type="pinyin",
@@ -448,7 +450,7 @@ def start_training(
f"--epochs {epochs} "
f"--num_warmup_updates {num_warmup_updates} "
f"--save_per_updates {save_per_updates} "
f"--last_per_steps {last_per_steps} "
f"--last_per_updates {last_per_updates} "
f"--dataset_name {dataset_name}"
)
@@ -482,7 +484,7 @@ def start_training(
epochs,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
finetune,
file_checkpoint_train,
tokenizer_type,
@@ -880,7 +882,7 @@ def calculate_train(
learning_rate,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
finetune,
):
path_project = os.path.join(path_data, name_project)
@@ -892,7 +894,7 @@ def calculate_train(
max_samples,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
"project not found !",
learning_rate,
)
@@ -940,14 +942,14 @@ def calculate_train(
num_warmup_updates = int(samples * 0.05)
save_per_updates = int(samples * 0.10)
last_per_steps = int(save_per_updates * 0.25)
last_per_updates = int(save_per_updates * 0.25)
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
if last_per_steps <= 0:
last_per_steps = 2
last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
if last_per_updates <= 0:
last_per_updates = 2
total_hours = hours
mel_hop_length = 256
@@ -978,7 +980,7 @@ def calculate_train(
max_samples,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
samples,
learning_rate,
int(epochs),
@@ -1530,7 +1532,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
with gr.TabItem("Train Data"):
gr.Markdown("""```plaintext
The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.
The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
with gr.Row():
@@ -1561,7 +1563,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=300)
last_per_steps = gr.Number(label="Last per Steps", value=100)
last_per_updates = gr.Number(label="Last per Updates", value=100)
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1582,7 +1584,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
epochsv,
num_warmupv_updatesv,
save_per_updatesv,
last_per_stepsv,
last_per_updatesv,
finetunev,
file_checkpoint_trainv,
tokenizer_typev,
@@ -1601,7 +1603,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
epochs.value = epochsv
num_warmup_updates.value = num_warmupv_updatesv
save_per_updates.value = save_per_updatesv
last_per_steps.value = last_per_stepsv
last_per_updates.value = last_per_updatesv
ch_finetune.value = finetunev
file_checkpoint_train.value = file_checkpoint_trainv
tokenizer_type.value = tokenizer_typev
@@ -1659,7 +1661,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
epochs,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
ch_finetune,
file_checkpoint_train,
tokenizer_type,
@@ -1682,7 +1684,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
learning_rate,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
ch_finetune,
],
outputs=[
@@ -1690,7 +1692,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
max_samples,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
lb_samples,
learning_rate,
epochs,
@@ -1713,7 +1715,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
epochs,
num_warmup_updates,
save_per_updates,
last_per_steps,
last_per_updates,
ch_finetune,
file_checkpoint_train,
tokenizer_type,

View File

@@ -55,7 +55,7 @@ def main(cfg):
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_steps=cfg.ckpts.last_per_steps,
last_per_updates=cfg.ckpts.last_per_updates,
log_samples=True,
bnb_optimizer=cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,