diff --git a/finetune_gradio.py b/finetune_gradio.py index 062e68d..a61d95a 100644 --- a/finetune_gradio.py +++ b/finetune_gradio.py @@ -254,6 +254,7 @@ def start_training( del tts_api gc.collect() torch.cuda.empty_cache() + tts_api = None path_project = os.path.join(path_data, dataset_name + "_pinyin") @@ -698,7 +699,7 @@ def get_random_sample_infer(project_name): ) -def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step): +def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step): global last_checkpoint, last_device, tts_api if not os.path.isfile(file_checkpoint): @@ -917,7 +918,7 @@ with gr.Blocks() as app: check_button_infer.click( fn=infer, - inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step], + inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step], outputs=[gen_audio], )