diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py
index 0024aa8..9deac09 100644
--- a/src/f5_tts/infer/infer_gradio.py
+++ b/src/f5_tts/infer/infer_gradio.py
@@ -454,13 +454,31 @@ Have a conversation with an AI using your reference voice!
 """
     )
 
-    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+    if not USING_SPACES:
+        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
 
-    chat_interface_container = gr.Column(visible=False)
+        chat_interface_container = gr.Column(visible=False)
+
+        @gpu_decorator
+        def load_chat_model():
+            global chat_model_state, chat_tokenizer_state
+            if chat_model_state is None:
+                show_info = gr.Info
+                show_info("Loading chat model...")
+                model_name = "Qwen/Qwen2.5-3B-Instruct"
+                chat_model_state = AutoModelForCausalLM.from_pretrained(
+                    model_name, torch_dtype="auto", device_map="auto"
+                )
+                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+                show_info("Chat model loaded.")
+
+            return gr.update(visible=False), gr.update(visible=True)
+
+        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
+
+    else:
+        chat_interface_container = gr.Column()
 
-    @gpu_decorator
-    def load_chat_model():
-        global chat_model_state, chat_tokenizer_state
         if chat_model_state is None:
             show_info = gr.Info
             show_info("Loading chat model...")
@@ -469,10 +487,6 @@ Have a conversation with an AI using your reference voice!
             chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
             show_info("Chat model loaded.")
 
-        return gr.update(visible=False), gr.update(visible=True)
-
-    load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
-
     with chat_interface_container:
         with gr.Row():
             with gr.Column():