add finetune miss

This commit is contained in:
unknown
2024-10-17 22:43:18 +03:00
parent 68718023ea
commit 3f3743eda4

View File

@@ -28,6 +28,7 @@ def parse_args():
parser.add_argument('--num_warmup_updates', type=int, default=5, help='Warmup steps')
parser.add_argument('--save_per_updates', type=int, default=10, help='Save checkpoint every X steps')
parser.add_argument('--last_per_steps', type=int, default=10, help='Save last checkpoint every X steps')
parser.add_argument('--finetune', type=bool, default=True, help='Use Finetune')
return parser.parse_args()
@@ -42,17 +43,21 @@ def main():
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
if args.finetune:
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
if args.finetune:
ckpt_path = str(cached_path(f"hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
if args.finetune:
path_ckpt = os.path.join("ckpts",args.dataset_name)
if os.path.isdir(path_ckpt)==False:
os.makedirs(path_ckpt,exist_ok=True)
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
path_ckpt = os.path.join("ckpts",args.dataset_name)
if os.path.isdir(path_ckpt)==False:
os.makedirs(path_ckpt,exist_ok=True)
shutil.copy2(ckpt_path,os.path.join(path_ckpt,os.path.basename(ckpt_path)))
checkpoint_path=os.path.join("ckpts",args.dataset_name)
# Use the dataset_name provided in the command line