0.4.1 #718 add keep_last_n_checkpoints option

SWivid
2025-01-15 15:06:55 +08:00
parent 76b1b03c4d
commit 12d6970271
9 changed files with 68 additions and 143 deletions
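
In short: the package version is bumped to 0.4.1 and the keep_last_n_checkpoints option from #718 is normalized across the codebase. The option now always sits directly after save_per_updates (in the training configs, the Trainer signature, the finetune CLI, the Gradio finetune app, and train.py), the Trainer-side type validation is dropped, and load_settings in the Gradio app now returns the settings dict instead of a long tuple.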

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "0.4.0"
version = "0.4.1"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}

View File

@@ -40,6 +40,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
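
The same two-line change is applied verbatim to all four training configs in this commit: keep_last_n_checkpoints moves up to sit directly after save_per_updates (it previously trailed last_per_updates) and its comment is shortened to the one-liner used elsewhere in the diff. As a worked example, with save_per_updates: 50000 and keep_last_n_checkpoints: 3, only the three most recent model_<update>.pt files remain on disk, while model_last.pt (refreshed every last_per_updates updates) is never pruned.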

View File

@@ -40,6 +40,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -43,6 +43,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -43,6 +43,6 @@ model:
ckpts:
logger: wandb # wandb | tensorboard | None
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -30,6 +30,7 @@ class Trainer:
learning_rate,
num_warmup_updates=20000,
save_per_updates=1000,
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
checkpoint_path=None,
batch_size=32,
batch_size_type: str = "sample",
@@ -50,17 +51,7 @@ class Trainer:
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
keep_last_n_checkpoints: int
| None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
):
# Validate keep_last_n_checkpoints
if not isinstance(keep_last_n_checkpoints, int):
raise ValueError("keep_last_n_checkpoints must be an integer")
if keep_last_n_checkpoints < -1:
raise ValueError(
"keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if logger == "wandb" and not wandb.api.api_key:
@@ -118,6 +109,7 @@ class Trainer:
self.epochs = epochs
self.num_warmup_updates = num_warmup_updates
self.save_per_updates = save_per_updates
self.keep_last_n_checkpoints = keep_last_n_checkpoints
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
@@ -144,8 +136,6 @@ class Trainer:
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None
@property
def is_main(self):
return self.accelerator.is_main_process
@@ -166,22 +156,16 @@ class Trainer:
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
print(f"Saved last checkpoint at update {update}")
else:
# Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
if self.keep_last_n_checkpoints == 0:
return
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
# Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
if self.keep_last_n_checkpoints > 0:
# Get all checkpoint files except model_last.pt
checkpoints = [
f
for f in os.listdir(self.checkpoint_path)
if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
]
# Sort by step number
checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
# Remove old checkpoints if we have more than keep_last_n_checkpoints
while len(checkpoints) > self.keep_last_n_checkpoints:
oldest_checkpoint = checkpoints.pop(0)
os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
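
To summarize the Trainer changes: keep_last_n_checkpoints becomes a plain int keyword placed right after save_per_updates (the earlier int | None keyword at the end of the signature and its runtime validation are removed), the attribute is assigned alongside the other hyperparameters, and save_checkpoint skips intermediate saves when the value is 0 and prunes old files when it is positive. A minimal standalone sketch of that pruning step, assuming the same model_<update>.pt naming; the helper name prune_checkpoints is illustrative and not part of the commit:

import os

def prune_checkpoints(checkpoint_path: str, keep_last_n_checkpoints: int) -> None:
    # -1 keeps everything; 0 is handled upstream by not writing intermediate checkpoints at all.
    if keep_last_n_checkpoints <= 0:
        return
    # Collect numbered intermediate checkpoints only; model_last.pt is never pruned.
    checkpoints = [
        f
        for f in os.listdir(checkpoint_path)
        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
    ]
    # Sort by the update number embedded in the filename, e.g. model_50000.pt -> 50000.
    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    # Delete the oldest files until only the N most recent remain.
    while len(checkpoints) > keep_last_n_checkpoints:
        oldest_checkpoint = checkpoints.pop(0)
        os.remove(os.path.join(checkpoint_path, oldest_checkpoint))

Because model_last.pt is filtered out up front, the rolling window only ever applies to the numbered intermediate checkpoints.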

View File

@@ -46,6 +46,12 @@ def parse_args():
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
@@ -69,12 +75,6 @@ def parse_args():
action="store_true",
help="Use 8-bit Adam optimizer from bitsandbytes",
)
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
)
return parser.parse_args()
@@ -151,6 +151,7 @@ def main():
args.learning_rate,
num_warmup_updates=args.num_warmup_updates,
save_per_updates=args.save_per_updates,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
checkpoint_path=checkpoint_path,
batch_size=args.batch_size_per_gpu,
batch_size_type=args.batch_size_type,
@@ -164,7 +165,6 @@ def main():
log_samples=args.log_samples,
last_per_updates=args.last_per_updates,
bnb_optimizer=args.bnb_optimizer,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
)
train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
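
With the flag registered in parse_args directly after --save_per_updates (rather than at the end of the parser, where it previously sat behind the bitsandbytes option) and forwarded to Trainer in the matching position, an existing finetune command only needs, for example, --keep_last_n_checkpoints 5 appended to cap disk usage at the five most recent intermediate checkpoints plus model_last.pt.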

View File

@@ -62,6 +62,7 @@ def save_settings(
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
finetune,
file_checkpoint_train,
@@ -70,7 +71,6 @@ def save_settings(
mixed_precision,
logger,
ch_8bit_adam,
keep_last_n_checkpoints,
):
path_project = os.path.join(path_project_ckpts, project_name)
os.makedirs(path_project, exist_ok=True)
@@ -87,6 +87,7 @@ def save_settings(
"epochs": epochs,
"num_warmup_updates": num_warmup_updates,
"save_per_updates": save_per_updates,
"keep_last_n_checkpoints": keep_last_n_checkpoints,
"last_per_updates": last_per_updates,
"finetune": finetune,
"file_checkpoint_train": file_checkpoint_train,
@@ -95,7 +96,6 @@ def save_settings(
"mixed_precision": mixed_precision,
"logger": logger,
"bnb_optimizer": ch_8bit_adam,
"keep_last_n_checkpoints": keep_last_n_checkpoints,
}
with open(file_setting, "w") as f:
json.dump(settings, f, indent=4)
@@ -120,6 +120,7 @@ def load_settings(project_name):
"epochs": 100,
"num_warmup_updates": 2,
"save_per_updates": 300,
"keep_last_n_checkpoints": -1,
"last_per_updates": 100,
"finetune": True,
"file_checkpoint_train": "",
@@ -128,30 +129,8 @@ def load_settings(project_name):
"mixed_precision": "none",
"logger": "wandb",
"bnb_optimizer": False,
"keep_last_n_checkpoints": -1, # Default to keep all checkpoints
}
return (
settings["exp_name"],
settings["learning_rate"],
settings["batch_size_per_gpu"],
settings["batch_size_type"],
settings["max_samples"],
settings["grad_accumulation_steps"],
settings["max_grad_norm"],
settings["epochs"],
settings["num_warmup_updates"],
settings["save_per_updates"],
settings["last_per_updates"],
settings["finetune"],
settings["file_checkpoint_train"],
settings["tokenizer_type"],
settings["tokenizer_file"],
settings["mixed_precision"],
settings["logger"],
settings["bnb_optimizer"],
settings["keep_last_n_checkpoints"],
)
else:
with open(file_setting, "r") as f:
settings = json.load(f)
if "logger" not in settings:
@@ -159,30 +138,11 @@ def load_settings(project_name):
if "bnb_optimizer" not in settings:
settings["bnb_optimizer"] = False
if "keep_last_n_checkpoints" not in settings:
settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints
settings["keep_last_n_checkpoints"] = -1 # default to keep all checkpoints
if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
return (
settings["exp_name"],
settings["learning_rate"],
settings["batch_size_per_gpu"],
settings["batch_size_type"],
settings["max_samples"],
settings["grad_accumulation_steps"],
settings["max_grad_norm"],
settings["epochs"],
settings["num_warmup_updates"],
settings["save_per_updates"],
settings["last_per_updates"],
settings["finetune"],
settings["file_checkpoint_train"],
settings["tokenizer_type"],
settings["tokenizer_file"],
settings["mixed_precision"],
settings["logger"],
settings["bnb_optimizer"],
settings["keep_last_n_checkpoints"],
)
return settings
# Load metadata
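
load_settings now returns the settings dict itself, with defaults patched in (including keep_last_n_checkpoints = -1 for projects saved before the option existed), instead of unpacking into a 19-element tuple. Callers such as the Gradio UI below read values by key, e.g. settings["keep_last_n_checkpoints"], so adding a new option no longer requires editing every unpacking site.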
@@ -388,6 +348,7 @@ def start_training(
epochs=11,
num_warmup_updates=200,
save_per_updates=400,
keep_last_n_checkpoints=-1,
last_per_updates=800,
finetune=True,
file_checkpoint_train="",
@@ -397,7 +358,6 @@ def start_training(
stream=False,
logger="wandb",
ch_8bit_adam=False,
keep_last_n_checkpoints=-1,
):
global training_process, tts_api, stop_signal
@@ -458,9 +418,9 @@ def start_training(
f" --epochs {epochs}"
f" --num_warmup_updates {num_warmup_updates}"
f" --save_per_updates {save_per_updates}"
f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
f" --last_per_updates {last_per_updates}"
f" --dataset_name {dataset_name}"
f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
)
if finetune:
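
Note that the removed f-string fragment had no leading space, so the old command string glued the flag onto the preceding --dataset_name value (producing something like "--dataset_name my_data--keep_last_n_checkpoints -1"); the rebuilt string adds the space and moves the flag up next to --save_per_updates.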
@@ -493,6 +453,7 @@ def start_training(
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
finetune,
file_checkpoint_train,
@@ -501,7 +462,6 @@ def start_training(
mixed_precision,
logger,
ch_8bit_adam,
keep_last_n_checkpoints,
)
try:
@@ -1573,7 +1533,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=300)
last_per_updates = gr.Number(label="Last per Updates", value=100)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
@@ -1581,6 +1540,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
precision=0,
info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
)
last_per_updates = gr.Number(label="Last per Updates", value=100)
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1590,46 +1550,27 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
stop_button = gr.Button("Stop Training", interactive=False)
if projects_selelect is not None:
(
exp_namev,
learning_ratev,
batch_size_per_gpuv,
batch_size_typev,
max_samplesv,
grad_accumulation_stepsv,
max_grad_normv,
epochsv,
num_warmupv_updatesv,
save_per_updatesv,
last_per_updatesv,
finetunev,
file_checkpoint_trainv,
tokenizer_typev,
tokenizer_filev,
mixed_precisionv,
cd_loggerv,
ch_8bit_adamv,
keep_last_n_checkpointsv,
) = load_settings(projects_selelect)
exp_name.value = exp_namev
learning_rate.value = learning_ratev
batch_size_per_gpu.value = batch_size_per_gpuv
batch_size_type.value = batch_size_typev
max_samples.value = max_samplesv
grad_accumulation_steps.value = grad_accumulation_stepsv
max_grad_norm.value = max_grad_normv
epochs.value = epochsv
num_warmup_updates.value = num_warmupv_updatesv
save_per_updates.value = save_per_updatesv
last_per_updates.value = last_per_updatesv
ch_finetune.value = finetunev
file_checkpoint_train.value = file_checkpoint_trainv
tokenizer_type.value = tokenizer_typev
tokenizer_file.value = tokenizer_filev
mixed_precision.value = mixed_precisionv
cd_logger.value = cd_loggerv
ch_8bit_adam.value = ch_8bit_adamv
keep_last_n_checkpoints.value = keep_last_n_checkpointsv
settings = load_settings(projects_selelect)
exp_name.value = settings["exp_name"]
learning_rate.value = settings["learning_rate"]
batch_size_per_gpu.value = settings["batch_size_per_gpu"]
batch_size_type.value = settings["batch_size_type"]
max_samples.value = settings["max_samples"]
grad_accumulation_steps.value = settings["grad_accumulation_steps"]
max_grad_norm.value = settings["max_grad_norm"]
epochs.value = settings["epochs"]
num_warmup_updates.value = settings["num_warmup_updates"]
save_per_updates.value = settings["save_per_updates"]
keep_last_n_checkpoints.value = settings["keep_last_n_checkpoints"]
last_per_updates.value = settings["last_per_updates"]
ch_finetune.value = settings["finetune"]
file_checkpoint_train.value = settings["file_checkpoint_train"]
tokenizer_type.value = settings["tokenizer_type"]
tokenizer_file.value = settings["tokenizer_file"]
mixed_precision.value = settings["mixed_precision"]
cd_logger.value = settings["logger"]
ch_8bit_adam.value = settings["bnb_optimizer"]
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
txt_info_train = gr.Text(label="Info", value="")
@@ -1680,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
ch_finetune,
file_checkpoint_train,
@@ -1689,7 +1631,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
ch_stream,
cd_logger,
ch_8bit_adam,
keep_last_n_checkpoints,
],
outputs=[txt_info_train, start_button, stop_button],
)

View File

@@ -45,6 +45,7 @@ def main(cfg):
learning_rate=cfg.optim.learning_rate,
num_warmup_updates=cfg.optim.num_warmup_updates,
save_per_updates=cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
batch_size=cfg.datasets.batch_size_per_gpu,
batch_size_type=cfg.datasets.batch_size_type,
@@ -61,7 +62,6 @@ def main(cfg):
mel_spec_type=mel_spec_type,
is_local_vocoder=cfg.model.vocoder.is_local,
local_vocoder_path=cfg.model.vocoder.local_path,
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
)
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
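
In train.py the keyword likewise moves up next to save_per_updates, and the getattr fallback changes from None to -1 to match the Trainer's plain int signature, so configs that predate the option keep every checkpoint without any None special-casing in the Trainer. Since the configs use OmegaConf-style interpolation (the ${...} references above), the value can presumably also be overridden per run, e.g. with ckpts.keep_last_n_checkpoints=5 on the command line, though the launch mechanics are outside this diff.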