Keep Last N Checkpoints (#718)

* Add checkpoint management feature

- Introduced a `keep_last_n_checkpoints` parameter in the configuration files and training scripts to control how many recent checkpoints are retained.
- Updated `finetune_cli.py`, `finetune_gradio.py`, and `trainer.py` to support this new parameter.
- Implemented logic to remove older checkpoints beyond the specified limit during training.
- Adjusted settings loading and saving to include the new checkpoint management option.

This enhancement prevents long training runs from consuming excessive storage with stale checkpoints; the retention policy is sketched below.
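A minimal standalone sketch of that retention policy (the function name and layout here are illustrative, not the project's API):

    import os

    def prune_checkpoints(ckpt_dir: str, keep_last_n: int) -> None:
        # -1 keeps everything; 0 or a positive N prunes down to the newest N intermediate checkpoints
        if keep_last_n < 0:
            return
        ckpts = [
            f for f in os.listdir(ckpt_dir)
            if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
        ]
        ckpts.sort(key=lambda f: int(f.split("_")[1].split(".")[0]))  # oldest first
        for old in ckpts[: max(len(ckpts) - keep_last_n, 0)]:
            os.remove(os.path.join(ckpt_dir, old))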
Commit 76b1b03c4d (parent 83efc3f038), authored by Hasan Can Solakoğlu and committed via GitHub on 2025-01-15 07:28:54 +03:00. 8 changed files with 64 additions and 1 deletion.


@@ -41,4 +41,5 @@ ckpts:
  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per updates
  last_per_updates: 5000 # save last checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
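For reference, the hunks in this commit touch Hydra-style YAML configs; reading the new key with OmegaConf could look like the following (the config path is hypothetical, and DictConfig.get lets configs written before this commit fall back to the default):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("F5TTS_Base_train.yaml")  # hypothetical config path
    keep_n = cfg.ckpts.get("keep_last_n_checkpoints", -1)  # -1: keep all checkpoints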


@@ -41,4 +41,5 @@ ckpts:
  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per updates
  last_per_updates: 5000 # save last checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -44,4 +44,5 @@ ckpts:
  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per updates
  last_per_updates: 5000 # save last checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -44,4 +44,5 @@ ckpts:
  logger: wandb # wandb | tensorboard | None
  save_per_updates: 50000 # save checkpoint per updates
  last_per_updates: 5000 # save last checkpoint per updates
  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}


@@ -50,7 +50,17 @@ class Trainer:
        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
        is_local_vocoder: bool = False,  # use local path vocoder
        local_vocoder_path: str = "",  # local vocoder path
        keep_last_n_checkpoints: int | None = -1,  # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
    ):
        # Validate keep_last_n_checkpoints; None (e.g. the key is absent from an older config) falls back to the default
        if keep_last_n_checkpoints is None:
            keep_last_n_checkpoints = -1
        elif not isinstance(keep_last_n_checkpoints, int):
            raise ValueError("keep_last_n_checkpoints must be an integer")
        if keep_last_n_checkpoints < -1:
            raise ValueError(
                "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or a positive integer"
            )

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        if logger == "wandb" and not wandb.api.api_key:
@@ -134,6 +144,8 @@ class Trainer:
        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        self.keep_last_n_checkpoints = keep_last_n_checkpoints  # already validated and normalized above

    @property
    def is_main(self):
        return self.accelerator.is_main_process
@@ -154,7 +166,26 @@
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
            print(f"Saved last checkpoint at update {update}")
        else:
            # Skip saving intermediate checkpoints entirely when keep_last_n_checkpoints is 0
            if self.keep_last_n_checkpoints == 0:
                return
            self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
            # Rolling checkpoint window: prune old files only when keep_last_n_checkpoints is positive
            if self.keep_last_n_checkpoints > 0:
                # Gather all intermediate checkpoints (everything except model_last.pt)
                checkpoints = [
                    f
                    for f in os.listdir(self.checkpoint_path)
                    if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
                ]
                # Sort by update number so the oldest come first
                checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                # Delete the oldest checkpoints beyond the retention limit
                while len(checkpoints) > self.keep_last_n_checkpoints:
                    oldest_checkpoint = checkpoints.pop(0)
                    os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
                    print(f"Removed old checkpoint: {oldest_checkpoint}")

    def load_checkpoint(self):
        if (
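One detail worth noting in the pruning logic above: checkpoints are sorted by their numeric update count, because lexicographic filename order would sort model_9000.pt after model_10000.pt. A quick illustration with made-up filenames:

    names = ["model_10000.pt", "model_9000.pt", "model_50000.pt"]
    print(sorted(names))
    # ['model_10000.pt', 'model_50000.pt', 'model_9000.pt']  -- lexicographic, wrong order
    print(sorted(names, key=lambda f: int(f.split("_")[1].split(".")[0])))
    # ['model_9000.pt', 'model_10000.pt', 'model_50000.pt']  -- numeric, oldest first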


@@ -69,6 +69,12 @@ def parse_args():
        action="store_true",
        help="Use 8-bit Adam optimizer from bitsandbytes",
    )
    parser.add_argument(
        "--keep_last_n_checkpoints",
        type=int,
        default=-1,
        help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
    )

    return parser.parse_args()
@@ -158,6 +164,7 @@ def main():
        log_samples=args.log_samples,
        last_per_updates=args.last_per_updates,
        bnb_optimizer=args.bnb_optimizer,
        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
    )

    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
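A sample invocation of the extended CLI (the dataset name and flag values are placeholders, and any other required arguments are omitted):

    python finetune_cli.py --dataset_name my_dataset --save_per_updates 50000 --last_per_updates 5000 --keep_last_n_checkpoints 3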


@@ -70,6 +70,7 @@ def save_settings(
    mixed_precision,
    logger,
    ch_8bit_adam,
    keep_last_n_checkpoints,
):
    path_project = os.path.join(path_project_ckpts, project_name)
    os.makedirs(path_project, exist_ok=True)
@@ -94,6 +95,7 @@ def save_settings(
        "mixed_precision": mixed_precision,
        "logger": logger,
        "bnb_optimizer": ch_8bit_adam,
        "keep_last_n_checkpoints": keep_last_n_checkpoints,
    }
    with open(file_setting, "w") as f:
        json.dump(settings, f, indent=4)
@@ -126,6 +128,7 @@ def load_settings(project_name):
        "mixed_precision": "none",
        "logger": "wandb",
        "bnb_optimizer": False,
        "keep_last_n_checkpoints": -1,  # default: keep all checkpoints
    }
    return (
        settings["exp_name"],
@@ -146,6 +149,7 @@ def load_settings(project_name):
        settings["mixed_precision"],
        settings["logger"],
        settings["bnb_optimizer"],
        settings["keep_last_n_checkpoints"],
    )

    with open(file_setting, "r") as f:
@@ -154,6 +158,8 @@ def load_settings(project_name):
        settings["logger"] = "wandb"
    if "bnb_optimizer" not in settings:
        settings["bnb_optimizer"] = False
    if "keep_last_n_checkpoints" not in settings:
        settings["keep_last_n_checkpoints"] = -1  # default: keep all checkpoints
    if "last_per_updates" not in settings:  # patch for backward compatibility, with before f992c4e
        settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
    return (
@@ -175,6 +181,7 @@ def load_settings(project_name):
        settings["mixed_precision"],
        settings["logger"],
        settings["bnb_optimizer"],
        settings["keep_last_n_checkpoints"],
    )
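The load path uses a backfill pattern for backward compatibility: keys that older versions of the settings file never wrote are filled with defaults before the dict is unpacked. The same idea in isolation (the key set and file handling are illustrative):

    import json

    DEFAULTS = {"bnb_optimizer": False, "keep_last_n_checkpoints": -1}

    def load_settings_compat(path: str) -> dict:
        with open(path, "r") as f:
            settings = json.load(f)
        for key, default in DEFAULTS.items():
            settings.setdefault(key, default)  # backfill keys missing from older files
        return settings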
@@ -390,6 +397,7 @@ def start_training(
    stream=False,
    logger="wandb",
    ch_8bit_adam=False,
    keep_last_n_checkpoints=-1,
):
    global training_process, tts_api, stop_signal
@@ -451,7 +459,8 @@ def start_training(
        f"--num_warmup_updates {num_warmup_updates} "
        f"--save_per_updates {save_per_updates} "
        f"--last_per_updates {last_per_updates} "
        f"--dataset_name {dataset_name} "
        f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
    )

    if finetune:
@@ -492,6 +501,7 @@ def start_training(
        mixed_precision,
        logger,
        ch_8bit_adam,
        keep_last_n_checkpoints,
    )

    try:
@@ -1564,6 +1574,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
            with gr.Row():
                save_per_updates = gr.Number(label="Save per Updates", value=300)
                last_per_updates = gr.Number(label="Last per Updates", value=100)
                keep_last_n_checkpoints = gr.Number(
                    label="Keep Last N Checkpoints",
                    value=-1,
                    step=1,
                    precision=0,
                    info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
                )

            with gr.Row():
                ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1592,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                mixed_precisionv,
                cd_loggerv,
                ch_8bit_adamv,
                keep_last_n_checkpointsv,
            ) = load_settings(projects_selelect)

            exp_name.value = exp_namev
            learning_rate.value = learning_ratev
@@ -1611,6 +1629,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
            mixed_precision.value = mixed_precisionv
            cd_logger.value = cd_loggerv
            ch_8bit_adam.value = ch_8bit_adamv
            keep_last_n_checkpoints.value = keep_last_n_checkpointsv

            ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
            txt_info_train = gr.Text(label="Info", value="")
@@ -1670,6 +1689,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                ch_stream,
                cd_logger,
                ch_8bit_adam,
                keep_last_n_checkpoints,
            ],
            outputs=[txt_info_train, start_button, stop_button],
        )


@@ -61,6 +61,7 @@ def main(cfg):
        mel_spec_type=mel_spec_type,
        is_local_vocoder=cfg.model.vocoder.is_local,
        local_vocoder_path=cfg.model.vocoder.local_path,
        keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
    )

    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
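Note the getattr fallback in the Trainer call above: a config written before this commit has no ckpts.keep_last_n_checkpoints attribute, so None is passed through and the Trainer normalizes it to -1 (keep all checkpoints). A sketch of the equivalent logic, not repo code:

    keep_n = getattr(cfg.ckpts, "keep_last_n_checkpoints", None)  # None if the key is absent
    if keep_n is None:
        keep_n = -1  # default: keep every checkpoint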