diff --git a/finetune-cli.py b/finetune-cli.py index 2f49e96..79ce9bb 100644 --- a/finetune-cli.py +++ b/finetune-cli.py @@ -28,6 +28,7 @@ def parse_args(): parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps') parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps') parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps') + parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune') return parser.parse_args() @@ -42,17 +43,21 @@ def main(): wandb_resume_id = None model_cls = DiT model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) - ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) + if args.finetune: + ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt")) elif args.exp_name == "E2TTS_Base": wandb_resume_id = None model_cls = UNetT model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) - ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) + if args.finetune: + ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt")) + + if args.finetune: + path_ckpt = os.path.join("ckpts",args.dataset_name) + if os.path.isdir(path_ckpt)==False: + os.makedirs(path_ckpt,exist_ok=True) + shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path))) - path_ckpt = os.path.join("ckpts",args.dataset_name) - if os.path.isdir(path_ckpt)==False: - os.makedirs(path_ckpt,exist_ok=True) - shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path))) checkpoint_path=os.path.join("ckpts",args.dataset_name) # Use the dataset_name provided in the command line