Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-26 12:51:16 -08:00)
Keep Last N Checkpoints (#718)
* Add checkpoint management feature

- Introduced a `keep_last_n_checkpoints` parameter in the configuration and training scripts to manage the number of recent checkpoints retained.
- Updated `finetune_cli.py`, `finetune_gradio.py`, and `trainer.py` to support this new parameter.
- Implemented logic to remove older checkpoints beyond the specified limit during training.
- Adjusted settings loading and saving to include the new checkpoint management option.

This enhancement improves the training process by preventing excessive storage usage from old checkpoints.
Committed by GitHub
Parent: 83efc3f038
Commit: 76b1b03c4d
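For orientation before the diffs: a minimal standalone sketch of the retention policy this commit introduces (the `prune_plan` helper and the example filenames are illustrative, not code from the repo). A value of -1 keeps every checkpoint, 0 skips writing intermediate checkpoints entirely, and a positive N keeps only the newest N intermediate checkpoints, while `model_last.pt` is always saved separately.

import re

def prune_plan(existing: list[str], keep_last_n: int) -> list[str]:
    # Illustrative only: which intermediate checkpoints a keep-last-N policy would delete.
    # -1 keeps everything; 0 means intermediates are never written, so there is nothing to prune.
    if keep_last_n <= 0:
        return []
    ordered = sorted(existing, key=lambda f: int(re.search(r"model_(\d+)\.pt", f).group(1)))
    return ordered[:-keep_last_n]

print(prune_plan(["model_5000.pt", "model_20000.pt", "model_10000.pt"], keep_last_n=2))
# ['model_5000.pt']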
@@ -41,4 +41,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -41,4 +41,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -44,4 +44,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -44,4 +44,5 @@ ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
   last_per_updates: 5000 # save last checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
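The new key sits under the `ckpts` block of the training configs. A quick check of how it reads back through OmegaConf, which the Hydra-based training entry point already relies on (the inline YAML here is a trimmed, illustrative snippet, not a full config file):

from omegaconf import OmegaConf

yaml_snippet = """
ckpts:
  logger: wandb
  save_per_updates: 50000
  last_per_updates: 5000
  keep_last_n_checkpoints: -1
"""
cfg = OmegaConf.create(yaml_snippet)
print(cfg.ckpts.keep_last_n_checkpoints)  # -1 -> keep every intermediate checkpoint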
@@ -50,7 +50,17 @@ class Trainer:
         mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
         is_local_vocoder: bool = False, # use local path vocoder
         local_vocoder_path: str = "", # local vocoder path
+        keep_last_n_checkpoints: int
+        | None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
     ):
+        # Validate keep_last_n_checkpoints
+        if not isinstance(keep_last_n_checkpoints, int):
+            raise ValueError("keep_last_n_checkpoints must be an integer")
+        if keep_last_n_checkpoints < -1:
+            raise ValueError(
+                "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
+            )
+
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

         if logger == "wandb" and not wandb.api.api_key:

@@ -134,6 +144,8 @@ class Trainer:
         self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

+        self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None
+
     @property
     def is_main(self):
         return self.accelerator.is_main_process

@@ -154,7 +166,26 @@ class Trainer:
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
                 print(f"Saved last checkpoint at update {update}")
             else:
+                # Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
+                if self.keep_last_n_checkpoints == 0:
+                    return
+
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
+                # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
+                if self.keep_last_n_checkpoints > 0:
+                    # Get all checkpoint files except model_last.pt
+                    checkpoints = [
+                        f
+                        for f in os.listdir(self.checkpoint_path)
+                        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                    ]
+                    # Sort by step number
+                    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
+                    # Remove old checkpoints if we have more than keep_last_n_checkpoints
+                    while len(checkpoints) > self.keep_last_n_checkpoints:
+                        oldest_checkpoint = checkpoints.pop(0)
+                        os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
+                        print(f"Removed old checkpoint: {oldest_checkpoint}")

     def load_checkpoint(self):
         if (
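A note on the sort key in the hunk above: checkpoints are ordered by the update number parsed out of the filename, because a plain lexicographic sort misorders names once the update count gains a digit (and `model_last.pt` is filtered out first, since `int("last")` would raise). A quick standalone illustration:

names = ["model_9000.pt", "model_10000.pt", "model_100000.pt"]

print(sorted(names))
# ['model_10000.pt', 'model_100000.pt', 'model_9000.pt']  <- lexicographic, wrong order

print(sorted(names, key=lambda x: int(x.split("_")[1].split(".")[0])))
# ['model_9000.pt', 'model_10000.pt', 'model_100000.pt']  <- numeric, oldest first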
@@ -69,6 +69,12 @@ def parse_args():
         action="store_true",
         help="Use 8-bit Adam optimizer from bitsandbytes",
     )
+    parser.add_argument(
+        "--keep_last_n_checkpoints",
+        type=int,
+        default=-1,
+        help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
+    )

     return parser.parse_args()

@@ -158,6 +164,7 @@ def main():
         log_samples=args.log_samples,
         last_per_updates=args.last_per_updates,
         bnb_optimizer=args.bnb_optimizer,
+        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
     )

     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
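A quick sanity check of the new CLI flag (a standalone argparse sketch that mirrors only the added argument, not the full parser in `finetune_cli.py`):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--keep_last_n_checkpoints",
    type=int,
    default=-1,
    help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
)

print(parser.parse_args([]).keep_last_n_checkpoints)                                  # -1 (default)
print(parser.parse_args(["--keep_last_n_checkpoints", "5"]).keep_last_n_checkpoints)  # 5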
@@ -70,6 +70,7 @@ def save_settings(
     mixed_precision,
     logger,
     ch_8bit_adam,
+    keep_last_n_checkpoints,
 ):
     path_project = os.path.join(path_project_ckpts, project_name)
     os.makedirs(path_project, exist_ok=True)

@@ -94,6 +95,7 @@ def save_settings(
         "mixed_precision": mixed_precision,
         "logger": logger,
         "bnb_optimizer": ch_8bit_adam,
+        "keep_last_n_checkpoints": keep_last_n_checkpoints,
     }
     with open(file_setting, "w") as f:
         json.dump(settings, f, indent=4)
@@ -126,6 +128,7 @@ def load_settings(project_name):
             "mixed_precision": "none",
             "logger": "wandb",
             "bnb_optimizer": False,
+            "keep_last_n_checkpoints": -1, # Default to keep all checkpoints
         }
         return (
             settings["exp_name"],

@@ -146,6 +149,7 @@ def load_settings(project_name):
             settings["mixed_precision"],
             settings["logger"],
             settings["bnb_optimizer"],
+            settings["keep_last_n_checkpoints"],
         )

     with open(file_setting, "r") as f:

@@ -154,6 +158,8 @@ def load_settings(project_name):
         settings["logger"] = "wandb"
     if "bnb_optimizer" not in settings:
         settings["bnb_optimizer"] = False
+    if "keep_last_n_checkpoints" not in settings:
+        settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints
     if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
         settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
     return (

@@ -175,6 +181,7 @@ def load_settings(project_name):
         settings["mixed_precision"],
         settings["logger"],
         settings["bnb_optimizer"],
+        settings["keep_last_n_checkpoints"],
     )

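The `if "..." not in settings:` blocks above backfill fields that older `setting.json` files were written without. The same pattern can be sketched with `dict.setdefault` (an illustrative alternative, not what the repo code does):

settings = {"exp_name": "F5TTS_Base", "logger": "wandb"}  # e.g. an old setting.json, loaded via json.load

# Backfill defaults for fields added after the file was written
settings.setdefault("bnb_optimizer", False)
settings.setdefault("keep_last_n_checkpoints", -1)  # default: keep all checkpoints

print(settings["keep_last_n_checkpoints"])  # -1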
@@ -390,6 +397,7 @@ def start_training(
     stream=False,
     logger="wandb",
     ch_8bit_adam=False,
+    keep_last_n_checkpoints=-1,
 ):
     global training_process, tts_api, stop_signal

@@ -451,7 +459,8 @@ def start_training(
         f"--num_warmup_updates {num_warmup_updates} "
         f"--save_per_updates {save_per_updates} "
         f"--last_per_updates {last_per_updates} "
-        f"--dataset_name {dataset_name}"
+        f"--dataset_name {dataset_name} "
+        f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
     )

     if finetune:

@@ -492,6 +501,7 @@ def start_training(
         mixed_precision,
         logger,
         ch_8bit_adam,
+        keep_last_n_checkpoints,
     )

     try:
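In the `@@ -451` hunk, note the trailing space added to the `--dataset_name` fragment: adjacent f-string literals concatenate with no separator, so once `--keep_last_n_checkpoints` is appended after it, the two flags would otherwise run together. A minimal illustration (the values are made up):

dataset_name = "Emilia_ZH_EN"
keep_last_n_checkpoints = 3

broken = f"--dataset_name {dataset_name}" f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
fixed = f"--dataset_name {dataset_name} " f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"

print(broken)  # --dataset_name Emilia_ZH_EN--keep_last_n_checkpoints 3  (flags glued together)
print(fixed)   # --dataset_name Emilia_ZH_EN --keep_last_n_checkpoints 3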
@@ -1564,6 +1574,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             with gr.Row():
                 save_per_updates = gr.Number(label="Save per Updates", value=300)
                 last_per_updates = gr.Number(label="Last per Updates", value=100)
+                keep_last_n_checkpoints = gr.Number(
+                    label="Keep Last N Checkpoints",
+                    value=-1,
+                    step=1,
+                    precision=0,
+                    info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
+                )

             with gr.Row():
                 ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")

@@ -1592,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 mixed_precisionv,
                 cd_loggerv,
                 ch_8bit_adamv,
+                keep_last_n_checkpointsv,
             ) = load_settings(projects_selelect)
             exp_name.value = exp_namev
             learning_rate.value = learning_ratev

@@ -1611,6 +1629,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
             mixed_precision.value = mixed_precisionv
             cd_logger.value = cd_loggerv
             ch_8bit_adam.value = ch_8bit_adamv
+            keep_last_n_checkpoints.value = keep_last_n_checkpointsv

             ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
             txt_info_train = gr.Text(label="Info", value="")

@@ -1670,6 +1689,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 ch_stream,
                 cd_logger,
                 ch_8bit_adam,
+                keep_last_n_checkpoints,
             ],
             outputs=[txt_info_train, start_button, stop_button],
         )
@@ -61,6 +61,7 @@ def main(cfg):
         mel_spec_type=mel_spec_type,
         is_local_vocoder=cfg.model.vocoder.is_local,
         local_vocoder_path=cfg.model.vocoder.local_path,
+        keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
     )

     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
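The `getattr(cfg.ckpts, "keep_last_n_checkpoints", None)` call lets configs written before this commit load without an attribute error. A tiny sketch of that fallback, using `types.SimpleNamespace` as a stand-in for the config node (illustrative only):

from types import SimpleNamespace

old_ckpts = SimpleNamespace(save_per_updates=50000, last_per_updates=5000)      # key absent (older config)
new_ckpts = SimpleNamespace(save_per_updates=50000, keep_last_n_checkpoints=3)  # key present

print(getattr(old_ckpts, "keep_last_n_checkpoints", None))  # None (no AttributeError)
print(getattr(new_ckpts, "keep_last_n_checkpoints", None))  # 3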