From 9a18bbe82327c55b007eea22ad24ba964314fa48 Mon Sep 17 00:00:00 2001 From: SWivid Date: Tue, 17 Dec 2024 07:57:33 +0800 Subject: [PATCH] v0.3.1 fix multi-style gradio bug; add features suggested #591 --- src/f5_tts/infer/infer_gradio.py | 100 +++++++++++++++---------------- 1 file changed, 47 insertions(+), 53 deletions(-) diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py index 9bd9780..10e661e 100644 --- a/src/f5_tts/infer/infer_gradio.py +++ b/src/f5_tts/infer/infer_gradio.py @@ -120,6 +120,14 @@ def infer( speed=1, show_info=gr.Info, ): + if not ref_audio_orig: + gr.Warning("Please provide reference audio.") + return gr.update(), gr.update(), ref_text + + if not gen_text.strip(): + gr.Warning("Please enter text to generate.") + return gr.update(), gr.update(), ref_text + ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info) if model == "F5-TTS": @@ -240,7 +248,7 @@ with gr.Blocks() as app_tts: nfe_step=nfe_slider, speed=speed_slider, ) - return audio_out, spectrogram_path, gr.update(value=ref_text_out) + return audio_out, spectrogram_path, ref_text_out generate_btn.click( basic_tts, @@ -320,7 +328,7 @@ with gr.Blocks() as app_multistyle: ) # Regular speech type (mandatory) - with gr.Row(): + with gr.Row() as regular_row: with gr.Column(): regular_name = gr.Textbox(value="Regular", label="Speech Type Name") regular_insert = gr.Button("Insert Label", variant="secondary") @@ -329,12 +337,12 @@ with gr.Blocks() as app_multistyle: # Regular speech type (max 100) max_speech_types = 100 - speech_type_rows = [] # 99 - speech_type_names = [regular_name] # 100 - speech_type_audios = [regular_audio] # 100 - speech_type_ref_texts = [regular_ref_text] # 100 - speech_type_delete_btns = [] # 99 - speech_type_insert_btns = [regular_insert] # 100 + speech_type_rows = [regular_row] + speech_type_names = [regular_name] + speech_type_audios = [regular_audio] + speech_type_ref_texts = [regular_ref_text] + speech_type_delete_btns = [None] + speech_type_insert_btns = [regular_insert] # Additional speech types (99 more) for i in range(max_speech_types - 1): @@ -355,51 +363,32 @@ with gr.Blocks() as app_multistyle: # Button to add speech type add_speech_type_btn = gr.Button("Add Speech Type") - # Keep track of current number of speech types - speech_type_count = gr.State(value=1) + # Keep track of autoincrement of speech types, no roll back + speech_type_count = 1 # Function to add a speech type - def add_speech_type_fn(speech_type_count): + def add_speech_type_fn(): + row_updates = [gr.update() for _ in range(max_speech_types)] + global speech_type_count if speech_type_count < max_speech_types: + row_updates[speech_type_count] = gr.update(visible=True) speech_type_count += 1 - # Prepare updates for the rows - row_updates = [] - for i in range(1, max_speech_types): - if i < speech_type_count: - row_updates.append(gr.update(visible=True)) - else: - row_updates.append(gr.update()) else: - # Optionally, show a warning - row_updates = [gr.update() for _ in range(1, max_speech_types)] - return [speech_type_count] + row_updates + gr.Warning("Exhausted maximum number of speech types. Consider restart the app.") + return row_updates - add_speech_type_btn.click( - add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows - ) + add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows) # Function to delete a speech type - def make_delete_speech_type_fn(index): - def delete_speech_type_fn(speech_type_count): - # Prepare updates - row_updates = [] - - for i in range(1, max_speech_types): - if i == index: - row_updates.append(gr.update(visible=False)) - else: - row_updates.append(gr.update()) - - speech_type_count = max(1, speech_type_count) - - return [speech_type_count] + row_updates - - return delete_speech_type_fn + def delete_speech_type_fn(): + return gr.update(visible=False), None, None, None # Update delete button clicks - for i, delete_btn in enumerate(speech_type_delete_btns): - delete_fn = make_delete_speech_type_fn(i) - delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows) + for i in range(1, len(speech_type_delete_btns)): + speech_type_delete_btns[i].click( + delete_speech_type_fn, + outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]], + ) # Text input for the prompt gen_text_input_multistyle = gr.Textbox( @@ -413,7 +402,7 @@ with gr.Blocks() as app_multistyle: current_text = current_text or "" speech_type_name = speech_type_name or "None" updated_text = current_text + f"{{{speech_type_name}}} " - return gr.update(value=updated_text) + return updated_text return insert_speech_type_fn @@ -473,10 +462,14 @@ with gr.Blocks() as app_multistyle: if style in speech_types: current_style = style else: - # If style not available, default to Regular + gr.Warning(f"Type {style} is not available, will use Regular as default.") current_style = "Regular" - ref_audio = speech_types[current_style]["audio"] + try: + ref_audio = speech_types[current_style]["audio"] + except KeyError: + gr.Warning(f"Please provide reference audio for type {current_style}.") + return [None] + [speech_types[style]["ref_text"] for style in speech_types] ref_text = speech_types[current_style].get("ref_text", "") # Generate speech for this segment @@ -491,12 +484,10 @@ with gr.Blocks() as app_multistyle: # Concatenate all audio segments if generated_audio_segments: final_audio_data = np.concatenate(generated_audio_segments) - return [(sr, final_audio_data)] + [ - gr.update(value=speech_types[style]["ref_text"]) for style in speech_types - ] + return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types] else: gr.Warning("No audio generated.") - return [None] + [gr.update(value=speech_types[style]["ref_text"]) for style in speech_types] + return [None] + [speech_types[style]["ref_text"] for style in speech_types] generate_multistyle_btn.click( generate_multistyle_speech, @@ -514,7 +505,7 @@ with gr.Blocks() as app_multistyle: # Validation function to disable Generate button if speech types are missing def validate_speech_types(gen_text, regular_name, *args): - speech_type_names_list = args[:max_speech_types] + speech_type_names_list = args # Collect the speech types names speech_types_available = set() @@ -678,7 +669,7 @@ Have a conversation with an AI using your reference voice! speed=1.0, show_info=print, # show_info=print no pull to top when generating ) - return audio_result, gr.update(value=ref_text_out) + return audio_result, ref_text_out def clear_conversation(): """Reset the conversation""" @@ -828,7 +819,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip visible=False, ) custom_model_cfg = gr.Dropdown( - choices=[DEFAULT_TTS_MODEL_CFG[2]], + choices=[ + DEFAULT_TTS_MODEL_CFG[2], + json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)), + ], value=load_last_used_custom()[2], allow_custom_value=True, label="Config: in a dictionary form",