diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py index 016640d..0637b38 100644 --- a/src/f5_tts/train/finetune_gradio.py +++ b/src/f5_tts/train/finetune_gradio.py @@ -1161,7 +1161,7 @@ def get_random_sample_infer(project_name): ) -def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema): +def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema): global last_checkpoint, last_device, tts_api, last_ema if not os.path.isfile(file_checkpoint): @@ -1182,7 +1182,11 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, us if last_ema != use_ema: last_ema = use_ema - tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test, use_ema=use_ema) + vocab_file = os.path.join(path_data, project, "vocab.txt") + + tts_api = F5TTS( + model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema + ) print("update >> ", device_test, file_checkpoint, use_ema) @@ -1630,7 +1634,7 @@ SOS : check the use_ema setting (True or False) for your model to see what works check_button_infer.click( fn=infer, - inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema], + inputs=[cm_project, cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema], outputs=[gen_audio, txt_info_gpu], )