mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-16 15:04:53 -08:00
add finetune miss
This commit is contained in:
@@ -28,6 +28,7 @@ def parse_args():
|
||||
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
|
||||
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
|
||||
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
|
||||
parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -42,17 +43,21 @@ def main():
|
||||
wandb_resume_id = None
|
||||
model_cls = DiT
|
||||
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
if args.finetune:
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
|
||||
elif args.exp_name == "E2TTS_Base":
|
||||
wandb_resume_id = None
|
||||
model_cls = UNetT
|
||||
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
if args.finetune:
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
|
||||
|
||||
if args.finetune:
|
||||
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
||||
if os.path.isdir(path_ckpt)==False:
|
||||
os.makedirs(path_ckpt,exist_ok=True)
|
||||
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
||||
|
||||
path_ckpt = os.path.join("ckpts",args.dataset_name)
|
||||
if os.path.isdir(path_ckpt)==False:
|
||||
os.makedirs(path_ckpt,exist_ok=True)
|
||||
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
|
||||
checkpoint_path=os.path.join("ckpts",args.dataset_name)
|
||||
|
||||
# Use the dataset_name provided in the command line
|
||||
|
||||
Reference in New Issue
Block a user