mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-20 22:33:36 -08:00
0.4.1 #718 add keep_last_n_checkpoints option
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "0.4.0"
|
||||
version = "0.4.1"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
|
||||
@@ -40,6 +40,6 @@ model:
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | None
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
last_per_updates: 5000 # save last checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
|
||||
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
||||
@@ -40,6 +40,6 @@ model:
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | None
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
last_per_updates: 5000 # save last checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
|
||||
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
||||
@@ -43,6 +43,6 @@ model:
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | None
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
last_per_updates: 5000 # save last checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
|
||||
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
||||
@@ -43,6 +43,6 @@ model:
|
||||
ckpts:
|
||||
logger: wandb # wandb | tensorboard | None
|
||||
save_per_updates: 50000 # save checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
last_per_updates: 5000 # save last checkpoint per updates
|
||||
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
|
||||
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
||||
@@ -30,6 +30,7 @@ class Trainer:
|
||||
learning_rate,
|
||||
num_warmup_updates=20000,
|
||||
save_per_updates=1000,
|
||||
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
|
||||
checkpoint_path=None,
|
||||
batch_size=32,
|
||||
batch_size_type: str = "sample",
|
||||
@@ -50,17 +51,7 @@ class Trainer:
|
||||
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
||||
is_local_vocoder: bool = False, # use local path vocoder
|
||||
local_vocoder_path: str = "", # local vocoder path
|
||||
keep_last_n_checkpoints: int
|
||||
| None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
|
||||
):
|
||||
# Validate keep_last_n_checkpoints
|
||||
if not isinstance(keep_last_n_checkpoints, int):
|
||||
raise ValueError("keep_last_n_checkpoints must be an integer")
|
||||
if keep_last_n_checkpoints < -1:
|
||||
raise ValueError(
|
||||
"keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
|
||||
)
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
|
||||
if logger == "wandb" and not wandb.api.api_key:
|
||||
@@ -118,6 +109,7 @@ class Trainer:
|
||||
self.epochs = epochs
|
||||
self.num_warmup_updates = num_warmup_updates
|
||||
self.save_per_updates = save_per_updates
|
||||
self.keep_last_n_checkpoints = keep_last_n_checkpoints
|
||||
self.last_per_updates = default(last_per_updates, save_per_updates)
|
||||
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
|
||||
|
||||
@@ -144,8 +136,6 @@ class Trainer:
|
||||
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
||||
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
|
||||
|
||||
self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None
|
||||
|
||||
@property
|
||||
def is_main(self):
|
||||
return self.accelerator.is_main_process
|
||||
@@ -166,22 +156,16 @@ class Trainer:
|
||||
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
|
||||
print(f"Saved last checkpoint at update {update}")
|
||||
else:
|
||||
# Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
|
||||
if self.keep_last_n_checkpoints == 0:
|
||||
return
|
||||
|
||||
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
|
||||
# Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
|
||||
if self.keep_last_n_checkpoints > 0:
|
||||
# Get all checkpoint files except model_last.pt
|
||||
checkpoints = [
|
||||
f
|
||||
for f in os.listdir(self.checkpoint_path)
|
||||
if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
|
||||
]
|
||||
# Sort by step number
|
||||
checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
|
||||
# Remove old checkpoints if we have more than keep_last_n_checkpoints
|
||||
while len(checkpoints) > self.keep_last_n_checkpoints:
|
||||
oldest_checkpoint = checkpoints.pop(0)
|
||||
os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
|
||||
|
||||
@@ -46,6 +46,12 @@ def parse_args():
|
||||
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
|
||||
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
|
||||
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
|
||||
parser.add_argument(
|
||||
"--keep_last_n_checkpoints",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
|
||||
)
|
||||
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
|
||||
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
|
||||
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
|
||||
@@ -69,12 +75,6 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Use 8-bit Adam optimizer from bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_last_n_checkpoints",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -151,6 +151,7 @@ def main():
|
||||
args.learning_rate,
|
||||
num_warmup_updates=args.num_warmup_updates,
|
||||
save_per_updates=args.save_per_updates,
|
||||
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
|
||||
checkpoint_path=checkpoint_path,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
batch_size_type=args.batch_size_type,
|
||||
@@ -164,7 +165,6 @@ def main():
|
||||
log_samples=args.log_samples,
|
||||
last_per_updates=args.last_per_updates,
|
||||
bnb_optimizer=args.bnb_optimizer,
|
||||
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
|
||||
|
||||
@@ -62,6 +62,7 @@ def save_settings(
|
||||
epochs,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
keep_last_n_checkpoints,
|
||||
last_per_updates,
|
||||
finetune,
|
||||
file_checkpoint_train,
|
||||
@@ -70,7 +71,6 @@ def save_settings(
|
||||
mixed_precision,
|
||||
logger,
|
||||
ch_8bit_adam,
|
||||
keep_last_n_checkpoints,
|
||||
):
|
||||
path_project = os.path.join(path_project_ckpts, project_name)
|
||||
os.makedirs(path_project, exist_ok=True)
|
||||
@@ -87,6 +87,7 @@ def save_settings(
|
||||
"epochs": epochs,
|
||||
"num_warmup_updates": num_warmup_updates,
|
||||
"save_per_updates": save_per_updates,
|
||||
"keep_last_n_checkpoints": keep_last_n_checkpoints,
|
||||
"last_per_updates": last_per_updates,
|
||||
"finetune": finetune,
|
||||
"file_checkpoint_train": file_checkpoint_train,
|
||||
@@ -95,7 +96,6 @@ def save_settings(
|
||||
"mixed_precision": mixed_precision,
|
||||
"logger": logger,
|
||||
"bnb_optimizer": ch_8bit_adam,
|
||||
"keep_last_n_checkpoints": keep_last_n_checkpoints,
|
||||
}
|
||||
with open(file_setting, "w") as f:
|
||||
json.dump(settings, f, indent=4)
|
||||
@@ -120,6 +120,7 @@ def load_settings(project_name):
|
||||
"epochs": 100,
|
||||
"num_warmup_updates": 2,
|
||||
"save_per_updates": 300,
|
||||
"keep_last_n_checkpoints": -1,
|
||||
"last_per_updates": 100,
|
||||
"finetune": True,
|
||||
"file_checkpoint_train": "",
|
||||
@@ -128,61 +129,20 @@ def load_settings(project_name):
|
||||
"mixed_precision": "none",
|
||||
"logger": "wandb",
|
||||
"bnb_optimizer": False,
|
||||
"keep_last_n_checkpoints": -1, # Default to keep all checkpoints
|
||||
}
|
||||
return (
|
||||
settings["exp_name"],
|
||||
settings["learning_rate"],
|
||||
settings["batch_size_per_gpu"],
|
||||
settings["batch_size_type"],
|
||||
settings["max_samples"],
|
||||
settings["grad_accumulation_steps"],
|
||||
settings["max_grad_norm"],
|
||||
settings["epochs"],
|
||||
settings["num_warmup_updates"],
|
||||
settings["save_per_updates"],
|
||||
settings["last_per_updates"],
|
||||
settings["finetune"],
|
||||
settings["file_checkpoint_train"],
|
||||
settings["tokenizer_type"],
|
||||
settings["tokenizer_file"],
|
||||
settings["mixed_precision"],
|
||||
settings["logger"],
|
||||
settings["bnb_optimizer"],
|
||||
settings["keep_last_n_checkpoints"],
|
||||
)
|
||||
else:
|
||||
with open(file_setting, "r") as f:
|
||||
settings = json.load(f)
|
||||
if "logger" not in settings:
|
||||
settings["logger"] = "wandb"
|
||||
if "bnb_optimizer" not in settings:
|
||||
settings["bnb_optimizer"] = False
|
||||
if "keep_last_n_checkpoints" not in settings:
|
||||
settings["keep_last_n_checkpoints"] = -1 # default to keep all checkpoints
|
||||
if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
|
||||
settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
|
||||
|
||||
with open(file_setting, "r") as f:
|
||||
settings = json.load(f)
|
||||
if "logger" not in settings:
|
||||
settings["logger"] = "wandb"
|
||||
if "bnb_optimizer" not in settings:
|
||||
settings["bnb_optimizer"] = False
|
||||
if "keep_last_n_checkpoints" not in settings:
|
||||
settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints
|
||||
if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
|
||||
settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
|
||||
return (
|
||||
settings["exp_name"],
|
||||
settings["learning_rate"],
|
||||
settings["batch_size_per_gpu"],
|
||||
settings["batch_size_type"],
|
||||
settings["max_samples"],
|
||||
settings["grad_accumulation_steps"],
|
||||
settings["max_grad_norm"],
|
||||
settings["epochs"],
|
||||
settings["num_warmup_updates"],
|
||||
settings["save_per_updates"],
|
||||
settings["last_per_updates"],
|
||||
settings["finetune"],
|
||||
settings["file_checkpoint_train"],
|
||||
settings["tokenizer_type"],
|
||||
settings["tokenizer_file"],
|
||||
settings["mixed_precision"],
|
||||
settings["logger"],
|
||||
settings["bnb_optimizer"],
|
||||
settings["keep_last_n_checkpoints"],
|
||||
)
|
||||
return settings
|
||||
|
||||
|
||||
# Load metadata
|
||||
@@ -388,6 +348,7 @@ def start_training(
|
||||
epochs=11,
|
||||
num_warmup_updates=200,
|
||||
save_per_updates=400,
|
||||
keep_last_n_checkpoints=-1,
|
||||
last_per_updates=800,
|
||||
finetune=True,
|
||||
file_checkpoint_train="",
|
||||
@@ -397,7 +358,6 @@ def start_training(
|
||||
stream=False,
|
||||
logger="wandb",
|
||||
ch_8bit_adam=False,
|
||||
keep_last_n_checkpoints=-1,
|
||||
):
|
||||
global training_process, tts_api, stop_signal
|
||||
|
||||
@@ -448,19 +408,19 @@ def start_training(
|
||||
fp16 = ""
|
||||
|
||||
cmd = (
|
||||
f"accelerate launch {fp16} {file_train} --exp_name {exp_name} "
|
||||
f"--learning_rate {learning_rate} "
|
||||
f"--batch_size_per_gpu {batch_size_per_gpu} "
|
||||
f"--batch_size_type {batch_size_type} "
|
||||
f"--max_samples {max_samples} "
|
||||
f"--grad_accumulation_steps {grad_accumulation_steps} "
|
||||
f"--max_grad_norm {max_grad_norm} "
|
||||
f"--epochs {epochs} "
|
||||
f"--num_warmup_updates {num_warmup_updates} "
|
||||
f"--save_per_updates {save_per_updates} "
|
||||
f"--last_per_updates {last_per_updates} "
|
||||
f"--dataset_name {dataset_name} "
|
||||
f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
|
||||
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
|
||||
f" --learning_rate {learning_rate}"
|
||||
f" --batch_size_per_gpu {batch_size_per_gpu}"
|
||||
f" --batch_size_type {batch_size_type}"
|
||||
f" --max_samples {max_samples}"
|
||||
f" --grad_accumulation_steps {grad_accumulation_steps}"
|
||||
f" --max_grad_norm {max_grad_norm}"
|
||||
f" --epochs {epochs}"
|
||||
f" --num_warmup_updates {num_warmup_updates}"
|
||||
f" --save_per_updates {save_per_updates}"
|
||||
f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
|
||||
f" --last_per_updates {last_per_updates}"
|
||||
f" --dataset_name {dataset_name}"
|
||||
)
|
||||
|
||||
if finetune:
|
||||
@@ -493,6 +453,7 @@ def start_training(
|
||||
epochs,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
keep_last_n_checkpoints,
|
||||
last_per_updates,
|
||||
finetune,
|
||||
file_checkpoint_train,
|
||||
@@ -501,7 +462,6 @@ def start_training(
|
||||
mixed_precision,
|
||||
logger,
|
||||
ch_8bit_adam,
|
||||
keep_last_n_checkpoints,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1573,7 +1533,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
|
||||
with gr.Row():
|
||||
save_per_updates = gr.Number(label="Save per Updates", value=300)
|
||||
last_per_updates = gr.Number(label="Last per Updates", value=100)
|
||||
keep_last_n_checkpoints = gr.Number(
|
||||
label="Keep Last N Checkpoints",
|
||||
value=-1,
|
||||
@@ -1581,6 +1540,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
precision=0,
|
||||
info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
|
||||
)
|
||||
last_per_updates = gr.Number(label="Last per Updates", value=100)
|
||||
|
||||
with gr.Row():
|
||||
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
|
||||
@@ -1590,46 +1550,27 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
stop_button = gr.Button("Stop Training", interactive=False)
|
||||
|
||||
if projects_selelect is not None:
|
||||
(
|
||||
exp_namev,
|
||||
learning_ratev,
|
||||
batch_size_per_gpuv,
|
||||
batch_size_typev,
|
||||
max_samplesv,
|
||||
grad_accumulation_stepsv,
|
||||
max_grad_normv,
|
||||
epochsv,
|
||||
num_warmupv_updatesv,
|
||||
save_per_updatesv,
|
||||
last_per_updatesv,
|
||||
finetunev,
|
||||
file_checkpoint_trainv,
|
||||
tokenizer_typev,
|
||||
tokenizer_filev,
|
||||
mixed_precisionv,
|
||||
cd_loggerv,
|
||||
ch_8bit_adamv,
|
||||
keep_last_n_checkpointsv,
|
||||
) = load_settings(projects_selelect)
|
||||
exp_name.value = exp_namev
|
||||
learning_rate.value = learning_ratev
|
||||
batch_size_per_gpu.value = batch_size_per_gpuv
|
||||
batch_size_type.value = batch_size_typev
|
||||
max_samples.value = max_samplesv
|
||||
grad_accumulation_steps.value = grad_accumulation_stepsv
|
||||
max_grad_norm.value = max_grad_normv
|
||||
epochs.value = epochsv
|
||||
num_warmup_updates.value = num_warmupv_updatesv
|
||||
save_per_updates.value = save_per_updatesv
|
||||
last_per_updates.value = last_per_updatesv
|
||||
ch_finetune.value = finetunev
|
||||
file_checkpoint_train.value = file_checkpoint_trainv
|
||||
tokenizer_type.value = tokenizer_typev
|
||||
tokenizer_file.value = tokenizer_filev
|
||||
mixed_precision.value = mixed_precisionv
|
||||
cd_logger.value = cd_loggerv
|
||||
ch_8bit_adam.value = ch_8bit_adamv
|
||||
keep_last_n_checkpoints.value = keep_last_n_checkpointsv
|
||||
settings = load_settings(projects_selelect)
|
||||
|
||||
exp_name.value = settings["exp_name"]
|
||||
learning_rate.value = settings["learning_rate"]
|
||||
batch_size_per_gpu.value = settings["batch_size_per_gpu"]
|
||||
batch_size_type.value = settings["batch_size_type"]
|
||||
max_samples.value = settings["max_samples"]
|
||||
grad_accumulation_steps.value = settings["grad_accumulation_steps"]
|
||||
max_grad_norm.value = settings["max_grad_norm"]
|
||||
epochs.value = settings["epochs"]
|
||||
num_warmup_updates.value = settings["num_warmup_updates"]
|
||||
save_per_updates.value = settings["save_per_updates"]
|
||||
keep_last_n_checkpoints.value = settings["keep_last_n_checkpoints"]
|
||||
last_per_updates.value = settings["last_per_updates"]
|
||||
ch_finetune.value = settings["finetune"]
|
||||
file_checkpoint_train.value = settings["file_checkpoint_train"]
|
||||
tokenizer_type.value = settings["tokenizer_type"]
|
||||
tokenizer_file.value = settings["tokenizer_file"]
|
||||
mixed_precision.value = settings["mixed_precision"]
|
||||
cd_logger.value = settings["logger"]
|
||||
ch_8bit_adam.value = settings["bnb_optimizer"]
|
||||
|
||||
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
|
||||
txt_info_train = gr.Text(label="Info", value="")
|
||||
@@ -1680,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
epochs,
|
||||
num_warmup_updates,
|
||||
save_per_updates,
|
||||
keep_last_n_checkpoints,
|
||||
last_per_updates,
|
||||
ch_finetune,
|
||||
file_checkpoint_train,
|
||||
@@ -1689,7 +1631,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
||||
ch_stream,
|
||||
cd_logger,
|
||||
ch_8bit_adam,
|
||||
keep_last_n_checkpoints,
|
||||
],
|
||||
outputs=[txt_info_train, start_button, stop_button],
|
||||
)
|
||||
|
||||
@@ -45,6 +45,7 @@ def main(cfg):
|
||||
learning_rate=cfg.optim.learning_rate,
|
||||
num_warmup_updates=cfg.optim.num_warmup_updates,
|
||||
save_per_updates=cfg.ckpts.save_per_updates,
|
||||
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
|
||||
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
|
||||
batch_size=cfg.datasets.batch_size_per_gpu,
|
||||
batch_size_type=cfg.datasets.batch_size_type,
|
||||
@@ -61,7 +62,6 @@ def main(cfg):
|
||||
mel_spec_type=mel_spec_type,
|
||||
is_local_vocoder=cfg.model.vocoder.is_local,
|
||||
local_vocoder_path=cfg.model.vocoder.local_path,
|
||||
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
|
||||
|
||||
Reference in New Issue
Block a user