diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py index 74d0cc2..bfd0499 100644 --- a/src/f5_tts/train/finetune_cli.py +++ b/src/f5_tts/train/finetune_cli.py @@ -6,6 +6,7 @@ from cached_path import cached_path from f5_tts.model import CFM, UNetT, DiT, Trainer from f5_tts.model.utils import get_tokenizer from f5_tts.model.dataset import load_dataset +from importlib.resources import files # -------------------------- Dataset Settings --------------------------- # @@ -63,6 +64,7 @@ def parse_args(): def main(): args = parse_args() + checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}")) # Model parameters based on experiment name if args.exp_name == "F5TTS_Base": @@ -85,12 +87,9 @@ def main(): ckpt_path = args.pretrain if args.finetune: - path_ckpt = os.path.join("ckpts", args.dataset_name) - if not os.path.isdir(path_ckpt): - os.makedirs(path_ckpt, exist_ok=True) - shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path))) - - checkpoint_path = os.path.join("ckpts", args.dataset_name) + if not os.path.isdir(checkpoint_path): + os.makedirs(checkpoint_path, exist_ok=True) + shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path))) # Use the tokenizer and tokenizer_path provided in the command line arguments tokenizer = args.tokenizer