mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-15 22:42:43 -08:00
fix path in finetune-cli working with new update (#270)
* fix path * change name
This commit is contained in:
@@ -6,6 +6,7 @@ from cached_path import cached_path
|
||||
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from importlib.resources import files
|
||||
|
||||
|
||||
# -------------------------- Dataset Settings --------------------------- #
|
||||
@@ -63,6 +64,7 @@ def parse_args():
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
|
||||
|
||||
# Model parameters based on experiment name
|
||||
if args.exp_name == "F5TTS_Base":
|
||||
@@ -85,12 +87,9 @@ def main():
|
||||
ckpt_path = args.pretrain
|
||||
|
||||
if args.finetune:
|
||||
path_ckpt = os.path.join("ckpts", args.dataset_name)
|
||||
if not os.path.isdir(path_ckpt):
|
||||
os.makedirs(path_ckpt, exist_ok=True)
|
||||
shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
|
||||
|
||||
checkpoint_path = os.path.join("ckpts", args.dataset_name)
|
||||
if not os.path.isdir(checkpoint_path):
|
||||
os.makedirs(checkpoint_path, exist_ok=True)
|
||||
shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path)))
|
||||
|
||||
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
||||
tokenizer = args.tokenizer
|
||||
|
||||
Reference in New Issue
Block a user