Mirror of https://github.com/SWivid/F5-TTS.git (synced 2026-01-27 23:34:17 -08:00)
0.4.0 fix gradient accumulation; change checkpointing logic to per_updates
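The change running through this diff is that warmup, checkpoint cadence, and logging are now counted in optimizer updates rather than dataloader (micro-batch) steps, with updates = steps / grad_accumulation_steps. A minimal sketch of that accounting in isolation, using made-up values rather than the repo's code:

    # illustrative only: one optimizer update per grad_accumulation_steps micro-batches
    grad_accumulation_steps = 2
    save_per_updates = 50000  # now counted in optimizer updates

    global_update = 0
    for step in range(1, 200001):  # each step is one micro-batch from the dataloader
        # ... forward / backward on the micro-batch ...
        if step % grad_accumulation_steps == 0:  # gradients sync -> one optimizer update
            global_update += 1
            if global_update % save_per_updates == 0:
                print(f"save checkpoint at update {global_update} (step {step})")

    # Before this commit the counter advanced every micro-batch step, so the same setting
    # produced a checkpoint only every save_per_updates * grad_accumulation_steps batches.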
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.3.4"
+version = "0.4.0"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup steps
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not

@@ -39,6 +39,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per steps
-  last_per_steps: 5000 # save last checkpoint per steps
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
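A config that still uses the old last_per_steps key can be migrated with the same arithmetic the backward-compatibility patch applies further down in this diff (last_per_updates = last_per_steps // grad_accumulation_steps). A quick sketch with illustrative numbers:

    # illustrative values; the division mirrors the compatibility patch in load_settings below
    old = {"grad_accumulation_steps": 2, "save_per_updates": 50000, "last_per_steps": 5000}
    last_per_updates = old["last_per_steps"] // old["grad_accumulation_steps"]  # 2500
    # with grad_accumulation_steps == 1 (the shipped default) the number is unchanged
    # and only the key name last_per_steps -> last_per_updates needs renaming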
@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup steps
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0
   bnb_optimizer: False

@@ -39,6 +39,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per steps
-  last_per_steps: 5000 # save last checkpoint per steps
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup steps
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not

@@ -42,6 +42,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per steps
-  last_per_steps: 5000 # save last checkpoint per steps
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

@@ -12,7 +12,7 @@ datasets:
 optim:
   epochs: 15
   learning_rate: 7.5e-5
-  num_warmup_updates: 20000 # warmup steps
+  num_warmup_updates: 20000 # warmup updates
   grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
   max_grad_norm: 1.0 # gradient clipping
   bnb_optimizer: False # use bnb 8bit AdamW optimizer or not

@@ -42,6 +42,6 @@ model:
 
 ckpts:
   logger: wandb # wandb | tensorboard | None
-  save_per_updates: 50000 # save checkpoint per steps
-  last_per_steps: 5000 # save last checkpoint per steps
+  save_per_updates: 50000 # save checkpoint per updates
+  last_per_updates: 5000 # save last checkpoint per updates
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import math
 import os
 
 import torch

@@ -42,7 +43,7 @@ class Trainer:
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
         log_samples: bool = False,
-        last_per_steps=None,
+        last_per_updates=None,
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,

@@ -57,6 +58,11 @@ class Trainer:
         print(f"Using logger: {logger}")
         self.log_samples = log_samples
 
+        if grad_accumulation_steps > 1 and self.is_main:
+            print(
+                "Gradient accumulation checkpointing with per_updates now, old logic per_steps used with before f992c4e"
+            )
+
         self.accelerator = Accelerator(
             log_with=logger if logger == "wandb" else None,
             kwargs_handlers=[ddp_kwargs],

@@ -102,7 +108,7 @@ class Trainer:
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
-        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
+        self.last_per_updates = default(last_per_updates, save_per_updates)
         self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
 
         self.batch_size = batch_size

@@ -132,7 +138,7 @@ class Trainer:
     def is_main(self):
         return self.accelerator.is_main_process
 
-    def save_checkpoint(self, step, last=False):
+    def save_checkpoint(self, update, last=False):
         self.accelerator.wait_for_everyone()
         if self.is_main:
             checkpoint = dict(

@@ -140,15 +146,15 @@ class Trainer:
                 optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                 ema_model_state_dict=self.ema_model.state_dict(),
                 scheduler_state_dict=self.scheduler.state_dict(),
-                step=step,
+                update=update,
             )
             if not os.path.exists(self.checkpoint_path):
                 os.makedirs(self.checkpoint_path)
             if last:
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
-                print(f"Saved last checkpoint at step {step}")
+                print(f"Saved last checkpoint at update {update}")
             else:
-                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
+                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
 
     def load_checkpoint(self):
         if (

@@ -177,7 +183,14 @@ class Trainer:
         if self.is_main:
             self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
 
-        if "step" in checkpoint:
+        if "update" in checkpoint or "step" in checkpoint:
+            # patch for backward compatibility, with before f992c4e
+            if "step" in checkpoint:
+                checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
+                if self.grad_accumulation_steps > 1 and self.is_main:
+                    print(
+                        "F5-TTS WARNING: Loading checkpoint saved with per_steps logic (before f992c4e), will convert to per_updates according to grad_accumulation_steps setting, may have unexpected behaviour."
+                    )
             # patch for backward compatibility, 305e3ea
             for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                 if key in checkpoint["model_state_dict"]:
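In effect, a checkpoint written before f992c4e resumes at step // grad_accumulation_steps optimizer updates. A worked example with illustrative numbers (not actual repo code):

    # illustrative: old checkpoints stored only "step"
    grad_accumulation_steps = 2
    checkpoint = {"step": 300000}
    if "update" not in checkpoint:
        checkpoint["update"] = checkpoint["step"] // grad_accumulation_steps
    print(checkpoint["update"])  # 150000 optimizer updates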
@@ -187,19 +200,19 @@ class Trainer:
             self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
             if self.scheduler:
                 self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
-            step = checkpoint["step"]
+            update = checkpoint["update"]
         else:
             checkpoint["model_state_dict"] = {
                 k.replace("ema_model.", ""): v
                 for k, v in checkpoint["ema_model_state_dict"].items()
-                if k not in ["initted", "step"]
+                if k not in ["initted", "update", "step"]
             }
             self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
-            step = 0
+            update = 0
 
         del checkpoint
         gc.collect()
-        return step
+        return update
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:

@@ -248,25 +261,26 @@ class Trainer:
 
         # accelerator.prepare() dispatches batches to devices;
         # which means the length of dataloader calculated before, should consider the number of devices
-        warmup_steps = (
+        warmup_updates = (
             self.num_warmup_updates * self.accelerator.num_processes
         )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
         # otherwise by default with split_batches=False, warmup steps change with num_processes
-        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
-        decay_steps = total_steps - warmup_steps
-        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
-        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
+        total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
+        decay_updates = total_updates - warmup_updates
+        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
+        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
         self.scheduler = SequentialLR(
-            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
+            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
         )
         train_dataloader, self.scheduler = self.accelerator.prepare(
             train_dataloader, self.scheduler
-        )  # actual steps = 1 gpu steps / gpus
-        start_step = self.load_checkpoint()
-        global_step = start_step
+        )  # actual multi_gpu updates = single_gpu updates / gpu nums
+        start_update = self.load_checkpoint()
+        global_update = start_update
 
         if exists(resumable_with_seed):
             orig_epoch_step = len(train_dataloader)
+            start_step = start_update * self.grad_accumulation_steps
             skipped_epoch = int(start_step // orig_epoch_step)
             skipped_batch = start_step % orig_epoch_step
             skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
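Since the scheduler is now stepped once per optimizer update, its total length follows from the dataloader length, the accumulation factor, and the epoch count. A worked example with assumed numbers (not the repo's defaults):

    import math

    len_train_dataloader = 1003     # micro-batches per epoch on this process (assumed)
    grad_accumulation_steps = 4
    epochs = 10
    warmup_updates = 500

    total_updates = math.ceil(len_train_dataloader / grad_accumulation_steps) * epochs  # 251 * 10 = 2510
    decay_updates = total_updates - warmup_updates                                      # 2010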
@@ -276,23 +290,21 @@ class Trainer:
         for epoch in range(skipped_epoch, self.epochs):
             self.model.train()
             if exists(resumable_with_seed) and epoch == skipped_epoch:
-                progress_bar = tqdm(
-                    skipped_dataloader,
-                    desc=f"Epoch {epoch+1}/{self.epochs}",
-                    unit="step",
-                    disable=not self.accelerator.is_local_main_process,
-                    initial=skipped_batch,
-                    total=orig_epoch_step,
-                )
+                progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
+                current_dataloader = skipped_dataloader
             else:
-                progress_bar = tqdm(
-                    train_dataloader,
-                    desc=f"Epoch {epoch+1}/{self.epochs}",
-                    unit="step",
-                    disable=not self.accelerator.is_local_main_process,
-                )
+                progress_bar_initial = 0
+                current_dataloader = train_dataloader
 
-            for batch in progress_bar:
+            progress_bar = tqdm(
+                range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
+                desc=f"Epoch {epoch+1}/{self.epochs}",
+                unit="update",
+                disable=not self.accelerator.is_local_main_process,
+                initial=progress_bar_initial,
+            )
+
+            for batch in current_dataloader:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
                     mel_spec = batch["mel"].permute(0, 2, 1)

@@ -301,7 +313,7 @@ class Trainer:
                 # TODO. add duration predictor training
                 if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                     dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
-                    self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
+                    self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
 
                 loss, cond, pred = self.model(
                     mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler

@@ -318,18 +330,20 @@ class Trainer:
                 if self.is_main and self.accelerator.sync_gradients:
                     self.ema_model.update()
 
-                global_step += 1
+                global_update += 1
+                progress_bar.update(1)
+                progress_bar.set_postfix(update=str(global_update), loss=loss.item())
 
                 if self.accelerator.is_local_main_process:
-                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    self.accelerator.log(
+                        {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
+                    )
                     if self.logger == "tensorboard":
-                        self.writer.add_scalar("loss", loss.item(), global_step)
-                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
+                        self.writer.add_scalar("loss", loss.item(), global_update)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
-                progress_bar.set_postfix(step=str(global_step), loss=loss.item())
-
-                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
-                    self.save_checkpoint(global_step)
+                if global_update % self.save_per_updates == 0:
+                    self.save_checkpoint(global_update)
 
                 if self.log_samples and self.accelerator.is_local_main_process:
                     ref_audio_len = mel_lengths[0]

@@ -355,12 +369,16 @@ class Trainer:
                     gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
                     ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
 
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
-                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
+                    )
+                    torchaudio.save(
+                        f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
+                    )
 
-                if global_step % self.last_per_steps == 0:
-                    self.save_checkpoint(global_step, last=True)
+                if global_update % self.last_per_updates == 0:
+                    self.save_checkpoint(global_update, last=True)
 
-        self.save_checkpoint(global_step, last=True)
+        self.save_checkpoint(global_update, last=True)
 
         self.accelerator.end_training()

@@ -20,13 +20,13 @@ grad_accum = 1
 mini_batch_frames = frames_per_gpu * grad_accum * gpus
 mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
 updates_per_epoch = total_hours / mini_batch_hours
-steps_per_epoch = updates_per_epoch * grad_accum
+# steps_per_epoch = updates_per_epoch * grad_accum
 
 # result
 epochs = wanted_max_updates / updates_per_epoch
 print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
 print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
-print(f" or approx. 0/{steps_per_epoch:.0f} steps")
+# print(f" or approx. 0/{steps_per_epoch:.0f} steps")
 
 # others
 print(f"total {total_hours:.0f} hours")
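Plugging assumed numbers into the helper above shows the shape of the calculation (all inputs here are illustrative, not the repo's defaults):

    wanted_max_updates = 1000000
    total_hours = 1000
    frames_per_gpu = 38400
    gpus = 8
    grad_accum = 1
    mel_hop_length = 256
    mel_sampling_rate = 24000

    mini_batch_frames = frames_per_gpu * grad_accum * gpus                            # 307200
    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600  # ~0.91 h
    updates_per_epoch = total_hours / mini_batch_hours                                # ~1099
    epochs = wanted_max_updates / updates_per_epoch                                   # ~910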
@@ -27,7 +27,7 @@ def parse_args():
 
     # num_warmup_updates = 300 for 5000 sample about 10 hours
 
-    # change save_per_updates , last_per_steps change this value what you need ,
+    # change save_per_updates , last_per_updates change this value what you need ,
 
     parser = argparse.ArgumentParser(description="Train CFM Model")

@@ -44,9 +44,9 @@ def parse_args():
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
     parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
-    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
-    parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
+    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
+    parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
     parser.add_argument("--finetune", action="store_true", help="Use Finetune")
     parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
     parser.add_argument(
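With the renamed arguments, a training launch now passes --last_per_updates instead of --last_per_steps. A sketch of just the flag string the launcher composes (values illustrative, command prefix omitted):

    flags = (
        f"--num_warmup_updates {300} "
        f"--save_per_updates {10000} "
        f"--last_per_updates {50000}"
    )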
@@ -61,7 +61,7 @@ def parse_args():
     parser.add_argument(
         "--log_samples",
         action="store_true",
-        help="Log inferenced samples per ckpt save steps",
+        help="Log inferenced samples per ckpt save updates",
     )
     parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
     parser.add_argument(

@@ -156,7 +156,7 @@ def main():
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
         log_samples=args.log_samples,
-        last_per_steps=args.last_per_steps,
+        last_per_updates=args.last_per_updates,
         bnb_optimizer=args.bnb_optimizer,
     )
 

@@ -62,7 +62,7 @@ def save_settings(
     epochs,
     num_warmup_updates,
     save_per_updates,
-    last_per_steps,
+    last_per_updates,
     finetune,
     file_checkpoint_train,
     tokenizer_type,

@@ -86,7 +86,7 @@ def save_settings(
         "epochs": epochs,
         "num_warmup_updates": num_warmup_updates,
         "save_per_updates": save_per_updates,
-        "last_per_steps": last_per_steps,
+        "last_per_updates": last_per_updates,
         "finetune": finetune,
         "file_checkpoint_train": file_checkpoint_train,
         "tokenizer_type": tokenizer_type,

@@ -118,7 +118,7 @@ def load_settings(project_name):
         "epochs": 100,
         "num_warmup_updates": 2,
         "save_per_updates": 300,
-        "last_per_steps": 100,
+        "last_per_updates": 100,
         "finetune": True,
         "file_checkpoint_train": "",
         "tokenizer_type": "pinyin",

@@ -138,7 +138,7 @@ def load_settings(project_name):
         settings["epochs"],
         settings["num_warmup_updates"],
         settings["save_per_updates"],
-        settings["last_per_steps"],
+        settings["last_per_updates"],
         settings["finetune"],
         settings["file_checkpoint_train"],
         settings["tokenizer_type"],

@@ -154,6 +154,8 @@ def load_settings(project_name):
         settings["logger"] = "wandb"
     if "bnb_optimizer" not in settings:
         settings["bnb_optimizer"] = False
+    if "last_per_updates" not in settings:  # patch for backward compatibility, with before f992c4e
+        settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
     return (
         settings["exp_name"],
         settings["learning_rate"],
@@ -165,7 +167,7 @@ def load_settings(project_name):
         settings["epochs"],
         settings["num_warmup_updates"],
         settings["save_per_updates"],
-        settings["last_per_steps"],
+        settings["last_per_updates"],
         settings["finetune"],
         settings["file_checkpoint_train"],
         settings["tokenizer_type"],

@@ -379,7 +381,7 @@ def start_training(
     epochs=11,
     num_warmup_updates=200,
     save_per_updates=400,
-    last_per_steps=800,
+    last_per_updates=800,
     finetune=True,
     file_checkpoint_train="",
     tokenizer_type="pinyin",

@@ -448,7 +450,7 @@ def start_training(
         f"--epochs {epochs} "
         f"--num_warmup_updates {num_warmup_updates} "
         f"--save_per_updates {save_per_updates} "
-        f"--last_per_steps {last_per_steps} "
+        f"--last_per_updates {last_per_updates} "
         f"--dataset_name {dataset_name}"
     )
 

@@ -482,7 +484,7 @@ def start_training(
         epochs,
         num_warmup_updates,
         save_per_updates,
-        last_per_steps,
+        last_per_updates,
         finetune,
         file_checkpoint_train,
         tokenizer_type,

@@ -880,7 +882,7 @@ def calculate_train(
     learning_rate,
     num_warmup_updates,
     save_per_updates,
-    last_per_steps,
+    last_per_updates,
     finetune,
 ):
     path_project = os.path.join(path_data, name_project)

@@ -892,7 +894,7 @@ def calculate_train(
             max_samples,
             num_warmup_updates,
             save_per_updates,
-            last_per_steps,
+            last_per_updates,
             "project not found !",
             learning_rate,
         )

@@ -940,14 +942,14 @@ def calculate_train(
 
     num_warmup_updates = int(samples * 0.05)
     save_per_updates = int(samples * 0.10)
-    last_per_steps = int(save_per_updates * 0.25)
+    last_per_updates = int(save_per_updates * 0.25)
 
     max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
     num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
     save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
-    last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
-    if last_per_steps <= 0:
-        last_per_steps = 2
+    last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
+    if last_per_updates <= 0:
+        last_per_updates = 2
 
     total_hours = hours
     mel_hop_length = 256

@@ -978,7 +980,7 @@ def calculate_train(
         max_samples,
         num_warmup_updates,
         save_per_updates,
-        last_per_steps,
+        last_per_updates,
         samples,
         learning_rate,
         int(epochs),

@@ -1530,7 +1532,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
 
         with gr.TabItem("Train Data"):
             gr.Markdown("""```plaintext
-The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.
+The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
 If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
 ```""")
             with gr.Row():

@@ -1561,7 +1563,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
 
             with gr.Row():
                 save_per_updates = gr.Number(label="Save per Updates", value=300)
-                last_per_steps = gr.Number(label="Last per Steps", value=100)
+                last_per_updates = gr.Number(label="Last per Updates", value=100)
 
             with gr.Row():
                 ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")

@@ -1582,7 +1584,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         epochsv,
         num_warmupv_updatesv,
         save_per_updatesv,
-        last_per_stepsv,
+        last_per_updatesv,
         finetunev,
         file_checkpoint_trainv,
         tokenizer_typev,

@@ -1601,7 +1603,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         epochs.value = epochsv
         num_warmup_updates.value = num_warmupv_updatesv
         save_per_updates.value = save_per_updatesv
-        last_per_steps.value = last_per_stepsv
+        last_per_updates.value = last_per_updatesv
         ch_finetune.value = finetunev
         file_checkpoint_train.value = file_checkpoint_trainv
         tokenizer_type.value = tokenizer_typev

@@ -1659,7 +1661,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 epochs,
                 num_warmup_updates,
                 save_per_updates,
-                last_per_steps,
+                last_per_updates,
                 ch_finetune,
                 file_checkpoint_train,
                 tokenizer_type,

@@ -1682,7 +1684,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 learning_rate,
                 num_warmup_updates,
                 save_per_updates,
-                last_per_steps,
+                last_per_updates,
                 ch_finetune,
             ],
             outputs=[

@@ -1690,7 +1692,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 max_samples,
                 num_warmup_updates,
                 save_per_updates,
-                last_per_steps,
+                last_per_updates,
                 lb_samples,
                 learning_rate,
                 epochs,

@@ -1713,7 +1715,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 epochs,
                 num_warmup_updates,
                 save_per_updates,
-                last_per_steps,
+                last_per_updates,
                 ch_finetune,
                 file_checkpoint_train,
                 tokenizer_type,
@@ -55,7 +55,7 @@ def main(cfg):
         wandb_project="CFM-TTS",
         wandb_run_name=exp_name,
         wandb_resume_id=wandb_resume_id,
-        last_per_steps=cfg.ckpts.last_per_steps,
+        last_per_updates=cfg.ckpts.last_per_updates,
         log_samples=True,
         bnb_optimizer=cfg.optim.bnb_optimizer,
         mel_spec_type=mel_spec_type,