fix space demo

This commit is contained in:
SWivid
2024-10-26 03:15:34 +08:00
parent 54d557789e
commit cc5ded275c

View File

@@ -454,13 +454,31 @@ Have a conversation with an AI using your reference voice!
"""
)
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
if not USING_SPACES:
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
chat_interface_container = gr.Column(visible=False)
chat_interface_container = gr.Column(visible=False)
@gpu_decorator
def load_chat_model():
global chat_model_state, chat_tokenizer_state
if chat_model_state is None:
show_info = gr.Info
show_info("Loading chat model...")
model_name = "Qwen/Qwen2.5-3B-Instruct"
chat_model_state = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="auto", device_map="auto"
)
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
show_info("Chat model loaded.")
return gr.update(visible=False), gr.update(visible=True)
load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
else:
chat_interface_container = gr.Column()
@gpu_decorator
def load_chat_model():
global chat_model_state, chat_tokenizer_state
if chat_model_state is None:
show_info = gr.Info
show_info("Loading chat model...")
@@ -469,10 +487,6 @@ Have a conversation with an AI using your reference voice!
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
show_info("Chat model loaded.")
return gr.update(visible=False), gr.update(visible=True)
load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
with chat_interface_container:
with gr.Row():
with gr.Column():