mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-20 22:33:36 -08:00
0.4.1 #718 add keep_last_n_checkpoints option
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.4.0"
+version = "0.4.1"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -40,6 +40,6 @@ model:
 ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
-  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -40,6 +40,6 @@ model:
 ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
-  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -43,6 +43,6 @@ model:
 ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
-  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
@@ -43,6 +43,6 @@ model:
 ckpts:
   logger: wandb # wandb | tensorboard | None
   save_per_updates: 50000 # save checkpoint per updates
+  keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
   last_per_updates: 5000 # save last checkpoint per updates
-  keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
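Every training config above gains the same ckpts.keep_last_n_checkpoints key: -1 keeps every intermediate checkpoint, 0 writes only model_last.pt, and a positive N keeps the newest N. A minimal sketch of reading the key with a backward-compatible default, assuming a plain PyYAML load and a hypothetical config path (not part of the commit):

import yaml  # PyYAML

# "train_config.yaml" is a hypothetical stand-in for any of the configs above.
with open("train_config.yaml") as f:
    cfg = yaml.safe_load(f)

ckpts = cfg.get("ckpts", {})
# Older configs predate the key, so fall back to -1 (keep all checkpoints),
# mirroring getattr(cfg.ckpts, "keep_last_n_checkpoints", -1) in train.py below.
keep_last_n_checkpoints = ckpts.get("keep_last_n_checkpoints", -1)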
@@ -30,6 +30,7 @@ class Trainer:
         learning_rate,
         num_warmup_updates=20000,
         save_per_updates=1000,
+        keep_last_n_checkpoints: int = -1,  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
         checkpoint_path=None,
         batch_size=32,
         batch_size_type: str = "sample",
@@ -50,17 +51,7 @@ class Trainer:
         mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
         is_local_vocoder: bool = False,  # use local path vocoder
         local_vocoder_path: str = "",  # local vocoder path
-        keep_last_n_checkpoints: int
-        | None = -1,  # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
     ):
-        # Validate keep_last_n_checkpoints
-        if not isinstance(keep_last_n_checkpoints, int):
-            raise ValueError("keep_last_n_checkpoints must be an integer")
-        if keep_last_n_checkpoints < -1:
-            raise ValueError(
-                "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
-            )
-
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
         if logger == "wandb" and not wandb.api.api_key:
@@ -118,6 +109,7 @@ class Trainer:
         self.epochs = epochs
         self.num_warmup_updates = num_warmup_updates
         self.save_per_updates = save_per_updates
+        self.keep_last_n_checkpoints = keep_last_n_checkpoints
         self.last_per_updates = default(last_per_updates, save_per_updates)
         self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
 
@@ -144,8 +136,6 @@ class Trainer:
         self.optimizer = AdamW(model.parameters(), lr=learning_rate)
         self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
 
-        self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None
-
     @property
     def is_main(self):
         return self.accelerator.is_main_process
@@ -166,22 +156,16 @@ class Trainer:
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
                 print(f"Saved last checkpoint at update {update}")
             else:
-                # Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
                 if self.keep_last_n_checkpoints == 0:
                     return
-
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
-                # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
                 if self.keep_last_n_checkpoints > 0:
-                    # Get all checkpoint files except model_last.pt
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
                         if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
                     ]
-                    # Sort by step number
                    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
-                    # Remove old checkpoints if we have more than keep_last_n_checkpoints
                     while len(checkpoints) > self.keep_last_n_checkpoints:
                         oldest_checkpoint = checkpoints.pop(0)
                         os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
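The rolling deletion kept above can be read as a small standalone routine. A sketch under the same model_<update>.pt naming convention; the directory and N below are illustrative, not taken from the repo:

import os


def prune_checkpoints(checkpoint_path: str, keep_last_n_checkpoints: int) -> None:
    # -1 keeps everything; 0 means intermediate checkpoints are never written at all.
    if keep_last_n_checkpoints <= 0:
        return
    checkpoints = [
        f
        for f in os.listdir(checkpoint_path)
        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
    ]
    checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))  # sort by update count
    while len(checkpoints) > keep_last_n_checkpoints:
        oldest = checkpoints.pop(0)
        os.remove(os.path.join(checkpoint_path, oldest))


# e.g. prune_checkpoints("ckpts/my_run", 3)  # "ckpts/my_run" is a made-up directory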
@@ -46,6 +46,12 @@ def parse_args():
     parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
     parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
     parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
+    parser.add_argument(
+        "--keep_last_n_checkpoints",
+        type=int,
+        default=-1,
+        help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
+    )
     parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
     parser.add_argument("--finetune", action="store_true", help="Use Finetune")
     parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
@@ -69,12 +75,6 @@ def parse_args():
         action="store_true",
         help="Use 8-bit Adam optimizer from bitsandbytes",
     )
-    parser.add_argument(
-        "--keep_last_n_checkpoints",
-        type=int,
-        default=-1,
-        help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
-    )
 
     return parser.parse_args()
 
@@ -151,6 +151,7 @@ def main():
         args.learning_rate,
         num_warmup_updates=args.num_warmup_updates,
         save_per_updates=args.save_per_updates,
+        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
         checkpoint_path=checkpoint_path,
         batch_size=args.batch_size_per_gpu,
         batch_size_type=args.batch_size_type,
@@ -164,7 +165,6 @@ def main():
         log_samples=args.log_samples,
         last_per_updates=args.last_per_updates,
         bnb_optimizer=args.bnb_optimizer,
-        keep_last_n_checkpoints=args.keep_last_n_checkpoints,
     )
 
     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
@@ -62,6 +62,7 @@ def save_settings(
     epochs,
     num_warmup_updates,
     save_per_updates,
+    keep_last_n_checkpoints,
     last_per_updates,
     finetune,
     file_checkpoint_train,
@@ -70,7 +71,6 @@ def save_settings(
     mixed_precision,
     logger,
     ch_8bit_adam,
-    keep_last_n_checkpoints,
 ):
     path_project = os.path.join(path_project_ckpts, project_name)
     os.makedirs(path_project, exist_ok=True)
@@ -87,6 +87,7 @@ def save_settings(
         "epochs": epochs,
         "num_warmup_updates": num_warmup_updates,
         "save_per_updates": save_per_updates,
+        "keep_last_n_checkpoints": keep_last_n_checkpoints,
         "last_per_updates": last_per_updates,
         "finetune": finetune,
         "file_checkpoint_train": file_checkpoint_train,
@@ -95,7 +96,6 @@ def save_settings(
         "mixed_precision": mixed_precision,
         "logger": logger,
         "bnb_optimizer": ch_8bit_adam,
-        "keep_last_n_checkpoints": keep_last_n_checkpoints,
     }
     with open(file_setting, "w") as f:
         json.dump(settings, f, indent=4)
@@ -120,6 +120,7 @@ def load_settings(project_name):
             "epochs": 100,
             "num_warmup_updates": 2,
             "save_per_updates": 300,
+            "keep_last_n_checkpoints": -1,
             "last_per_updates": 100,
             "finetune": True,
             "file_checkpoint_train": "",
@@ -128,61 +129,20 @@ def load_settings(project_name):
             "mixed_precision": "none",
             "logger": "wandb",
             "bnb_optimizer": False,
-            "keep_last_n_checkpoints": -1,  # Default to keep all checkpoints
         }
-        return (
-            settings["exp_name"],
-            settings["learning_rate"],
-            settings["batch_size_per_gpu"],
-            settings["batch_size_type"],
-            settings["max_samples"],
-            settings["grad_accumulation_steps"],
-            settings["max_grad_norm"],
-            settings["epochs"],
-            settings["num_warmup_updates"],
-            settings["save_per_updates"],
-            settings["last_per_updates"],
-            settings["finetune"],
-            settings["file_checkpoint_train"],
-            settings["tokenizer_type"],
-            settings["tokenizer_file"],
-            settings["mixed_precision"],
-            settings["logger"],
-            settings["bnb_optimizer"],
-            settings["keep_last_n_checkpoints"],
-        )
+    else:
+        with open(file_setting, "r") as f:
+            settings = json.load(f)
+        if "logger" not in settings:
+            settings["logger"] = "wandb"
+        if "bnb_optimizer" not in settings:
+            settings["bnb_optimizer"] = False
+        if "keep_last_n_checkpoints" not in settings:
+            settings["keep_last_n_checkpoints"] = -1  # default to keep all checkpoints
+        if "last_per_updates" not in settings:  # patch for backward compatibility, with before f992c4e
+            settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
 
-    with open(file_setting, "r") as f:
-        settings = json.load(f)
-    if "logger" not in settings:
-        settings["logger"] = "wandb"
-    if "bnb_optimizer" not in settings:
-        settings["bnb_optimizer"] = False
-    if "keep_last_n_checkpoints" not in settings:
-        settings["keep_last_n_checkpoints"] = -1  # Default to keep all checkpoints
-    if "last_per_updates" not in settings:  # patch for backward compatibility, with before f992c4e
-        settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
-    return (
-        settings["exp_name"],
-        settings["learning_rate"],
-        settings["batch_size_per_gpu"],
-        settings["batch_size_type"],
-        settings["max_samples"],
-        settings["grad_accumulation_steps"],
-        settings["max_grad_norm"],
-        settings["epochs"],
-        settings["num_warmup_updates"],
-        settings["save_per_updates"],
-        settings["last_per_updates"],
-        settings["finetune"],
-        settings["file_checkpoint_train"],
-        settings["tokenizer_type"],
-        settings["tokenizer_file"],
-        settings["mixed_precision"],
-        settings["logger"],
-        settings["bnb_optimizer"],
-        settings["keep_last_n_checkpoints"],
-    )
+    return settings
 
 
 # Load metadata
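load_settings now returns the settings dict directly and back-fills keys that older setting.json files lack. A compact sketch of that defaulting pattern, with a hypothetical helper name and file path:

import json


def load_settings_sketch(file_setting: str) -> dict:  # hypothetical helper name
    with open(file_setting, "r") as f:
        settings = json.load(f)
    # Back-fill keys missing from setting.json files written by older versions.
    settings.setdefault("logger", "wandb")
    settings.setdefault("bnb_optimizer", False)
    settings.setdefault("keep_last_n_checkpoints", -1)  # -1 = keep all checkpoints
    return settings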
@@ -388,6 +348,7 @@ def start_training(
     epochs=11,
     num_warmup_updates=200,
     save_per_updates=400,
+    keep_last_n_checkpoints=-1,
     last_per_updates=800,
     finetune=True,
     file_checkpoint_train="",
@@ -397,7 +358,6 @@ def start_training(
     stream=False,
     logger="wandb",
     ch_8bit_adam=False,
-    keep_last_n_checkpoints=-1,
 ):
     global training_process, tts_api, stop_signal
 
@@ -448,19 +408,19 @@ def start_training(
         fp16 = ""
 
     cmd = (
-        f"accelerate launch {fp16} {file_train} --exp_name {exp_name} "
-        f"--learning_rate {learning_rate} "
-        f"--batch_size_per_gpu {batch_size_per_gpu} "
-        f"--batch_size_type {batch_size_type} "
-        f"--max_samples {max_samples} "
-        f"--grad_accumulation_steps {grad_accumulation_steps} "
-        f"--max_grad_norm {max_grad_norm} "
-        f"--epochs {epochs} "
-        f"--num_warmup_updates {num_warmup_updates} "
-        f"--save_per_updates {save_per_updates} "
-        f"--last_per_updates {last_per_updates} "
-        f"--dataset_name {dataset_name} "
-        f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
+        f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
+        f" --learning_rate {learning_rate}"
+        f" --batch_size_per_gpu {batch_size_per_gpu}"
+        f" --batch_size_type {batch_size_type}"
+        f" --max_samples {max_samples}"
+        f" --grad_accumulation_steps {grad_accumulation_steps}"
+        f" --max_grad_norm {max_grad_norm}"
+        f" --epochs {epochs}"
+        f" --num_warmup_updates {num_warmup_updates}"
+        f" --save_per_updates {save_per_updates}"
+        f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
+        f" --last_per_updates {last_per_updates}"
+        f" --dataset_name {dataset_name}"
     )
 
     if finetune:
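The command above is assembled from adjacent f-string literals; the separating space now leads each continuation line, so every flag, including the new --keep_last_n_checkpoints, stays self-delimiting. A small illustration with made-up values standing in for the GUI fields and the training script path:

# Illustrative values; the real code fills these from the Gradio widgets.
exp_name, learning_rate, keep_last_n_checkpoints = "F5TTS_Base", 1e-5, 5
cmd = (
    f"accelerate launch finetune_cli.py --exp_name {exp_name}"
    f" --learning_rate {learning_rate}"
    f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
)
print(cmd)
# accelerate launch finetune_cli.py --exp_name F5TTS_Base --learning_rate 1e-05 --keep_last_n_checkpoints 5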
@@ -493,6 +453,7 @@ def start_training(
         epochs,
         num_warmup_updates,
         save_per_updates,
+        keep_last_n_checkpoints,
         last_per_updates,
         finetune,
         file_checkpoint_train,
@@ -501,7 +462,6 @@ def start_training(
         mixed_precision,
         logger,
         ch_8bit_adam,
-        keep_last_n_checkpoints,
     )
 
     try:
@@ -1573,7 -1533,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
 
             with gr.Row():
                 save_per_updates = gr.Number(label="Save per Updates", value=300)
-                last_per_updates = gr.Number(label="Last per Updates", value=100)
                 keep_last_n_checkpoints = gr.Number(
                     label="Keep Last N Checkpoints",
                     value=-1,
@@ -1581,6 +1540,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                     precision=0,
                     info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
                 )
+                last_per_updates = gr.Number(label="Last per Updates", value=100)
 
             with gr.Row():
                 ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1590,46 +1550,27 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                 stop_button = gr.Button("Stop Training", interactive=False)
 
             if projects_selelect is not None:
-                (
-                    exp_namev,
-                    learning_ratev,
-                    batch_size_per_gpuv,
-                    batch_size_typev,
-                    max_samplesv,
-                    grad_accumulation_stepsv,
-                    max_grad_normv,
-                    epochsv,
-                    num_warmupv_updatesv,
-                    save_per_updatesv,
-                    last_per_updatesv,
-                    finetunev,
-                    file_checkpoint_trainv,
-                    tokenizer_typev,
-                    tokenizer_filev,
-                    mixed_precisionv,
-                    cd_loggerv,
-                    ch_8bit_adamv,
-                    keep_last_n_checkpointsv,
-                ) = load_settings(projects_selelect)
-                exp_name.value = exp_namev
-                learning_rate.value = learning_ratev
-                batch_size_per_gpu.value = batch_size_per_gpuv
-                batch_size_type.value = batch_size_typev
-                max_samples.value = max_samplesv
-                grad_accumulation_steps.value = grad_accumulation_stepsv
-                max_grad_norm.value = max_grad_normv
-                epochs.value = epochsv
-                num_warmup_updates.value = num_warmupv_updatesv
-                save_per_updates.value = save_per_updatesv
-                last_per_updates.value = last_per_updatesv
-                ch_finetune.value = finetunev
-                file_checkpoint_train.value = file_checkpoint_trainv
-                tokenizer_type.value = tokenizer_typev
-                tokenizer_file.value = tokenizer_filev
-                mixed_precision.value = mixed_precisionv
-                cd_logger.value = cd_loggerv
-                ch_8bit_adam.value = ch_8bit_adamv
-                keep_last_n_checkpoints.value = keep_last_n_checkpointsv
+                settings = load_settings(projects_selelect)
+
+                exp_name.value = settings["exp_name"]
+                learning_rate.value = settings["learning_rate"]
+                batch_size_per_gpu.value = settings["batch_size_per_gpu"]
+                batch_size_type.value = settings["batch_size_type"]
+                max_samples.value = settings["max_samples"]
+                grad_accumulation_steps.value = settings["grad_accumulation_steps"]
+                max_grad_norm.value = settings["max_grad_norm"]
+                epochs.value = settings["epochs"]
+                num_warmup_updates.value = settings["num_warmup_updates"]
+                save_per_updates.value = settings["save_per_updates"]
+                keep_last_n_checkpoints.value = settings["keep_last_n_checkpoints"]
+                last_per_updates.value = settings["last_per_updates"]
+                ch_finetune.value = settings["finetune"]
+                file_checkpoint_train.value = settings["file_checkpoint_train"]
+                tokenizer_type.value = settings["tokenizer_type"]
+                tokenizer_file.value = settings["tokenizer_file"]
+                mixed_precision.value = settings["mixed_precision"]
+                cd_logger.value = settings["logger"]
+                ch_8bit_adam.value = settings["bnb_optimizer"]
 
             ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
             txt_info_train = gr.Text(label="Info", value="")
@@ -1680,6 +1621,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                     epochs,
                     num_warmup_updates,
                     save_per_updates,
+                    keep_last_n_checkpoints,
                     last_per_updates,
                     ch_finetune,
                     file_checkpoint_train,
@@ -1689,7 +1631,6 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
                     ch_stream,
                     cd_logger,
                     ch_8bit_adam,
-                    keep_last_n_checkpoints,
                 ],
                 outputs=[txt_info_train, start_button, stop_button],
             )
@@ -45,6 +45,7 @@ def main(cfg):
         learning_rate=cfg.optim.learning_rate,
         num_warmup_updates=cfg.optim.num_warmup_updates,
         save_per_updates=cfg.ckpts.save_per_updates,
+        keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
         checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
         batch_size=cfg.datasets.batch_size_per_gpu,
         batch_size_type=cfg.datasets.batch_size_type,
@@ -61,7 +62,6 @@ def main(cfg):
         mel_spec_type=mel_spec_type,
         is_local_vocoder=cfg.model.vocoder.is_local,
         local_vocoder_path=cfg.model.vocoder.local_path,
-        keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
     )
 
     train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)