Mirror of https://github.com/SWivid/F5-TTS.git, synced 2025-12-05 20:40:12 -08:00
Update infer_gradio.py
Added "randomize seed" checkmark and option to specify seed showing last seed used and can manually enter the desired seed number.
@@ -129,6 +129,7 @@ def infer(
     gen_text_file,
     model,
     remove_silence,
+    seed,
     cross_fade_duration=0.15,
     nfe_step=32,
     speed=1,
@@ -146,6 +147,12 @@ def infer(
         gr.Warning("Please enter text to generate or upload a text file.")
         return gr.update(), gr.update(), ref_text

+    # Set random seed for reproducibility
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

     if model == DEFAULT_TTS_MODEL:
@@ -206,7 +213,7 @@ with gr.Blocks() as app_tts:
     gr.Markdown("# Batched TTS")
     ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
     gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
-    with gr.Column(scale=2):
+    with gr.Column(scale=1):
         gen_text_file = gr.File(label="Upload Text File to Generate (.txt)", file_types=[".txt"])
     generate_btn = gr.Button("Synthesize", variant="primary")
     with gr.Accordion("Advanced Settings", open=False):
@@ -223,6 +230,18 @@ with gr.Blocks() as app_tts:
             info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
             value=False,
         )
+        with gr.Row():
+            randomize_seed = gr.Checkbox(
+                label="Randomize Seed",
+                value=True,
+                info="Check to use a random seed for each generation. Uncheck to use the seed specified below.",
+            )
+            seed_input = gr.Textbox(
+                label="Seed",
+                value="0",
+                placeholder="Enter a seed value",
+                scale=1,
+            )
         speed_slider = gr.Slider(
             label="Speed",
             minimum=0.3,
@@ -271,10 +290,25 @@ with gr.Blocks() as app_tts:
         gen_text_input,
         gen_text_file,
         remove_silence,
+        randomize_seed,
+        seed_input,
         cross_fade_duration_slider,
         nfe_slider,
         speed_slider,
     ):
+        # Determine the seed to use
+        if randomize_seed:
+            seed = np.random.randint(0, 2**31)
+        else:
+            try:
+                seed = int(seed_input)
+                if seed < 0:
+                    gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
+                    seed = np.random.randint(0, 2**31)
+            except ValueError:
+                gr.Warning("Invalid seed value. Using random seed instead.")
+                seed = np.random.randint(0, 2**31)
+
         audio_out, spectrogram_path, ref_text_out = infer(
             ref_audio_input,
             ref_text_input,
@@ -283,11 +317,12 @@ with gr.Blocks() as app_tts:
             gen_text_file,
             tts_model_choice,
             remove_silence,
+            seed=seed,
             cross_fade_duration=cross_fade_duration_slider,
             nfe_step=nfe_slider,
             speed=speed_slider,
         )
-        return audio_out, spectrogram_path, ref_text_out
+        return audio_out, spectrogram_path, ref_text_out, str(seed)

     gen_text_file.change(
         update_gen_text_from_file,
@@ -310,11 +345,13 @@ with gr.Blocks() as app_tts:
             gen_text_input,
             gen_text_file,
             remove_silence,
+            randomize_seed,
+            seed_input,
             cross_fade_duration_slider,
             nfe_slider,
             speed_slider,
         ],
-        outputs=[audio_output, spectrogram_output, ref_text_input],
+        outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
     )

@@ -501,6 +538,18 @@ with gr.Blocks() as app_multistyle:
         label="Remove Silences",
         value=True,
     )
+    with gr.Row():
+        randomize_seed_multistyle = gr.Checkbox(
+            label="Randomize Seed",
+            value=True,
+            info="Check to use a random seed for each generation. Uncheck to use the seed specified below.",
+        )
+        seed_input_multistyle = gr.Textbox(
+            label="Seed",
+            value="0",
+            placeholder="Enter a seed value",
+            scale=1,
+        )

     # Generate button
     generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -524,8 +573,23 @@ with gr.Blocks() as app_multistyle:
     def generate_multistyle_speech(
         gen_text,
         gen_text_file,
+        randomize_seed,
+        seed_input,
         *args,
     ):
+        # Determine the seed to use
+        if randomize_seed:
+            seed = np.random.randint(0, 2**31)
+        else:
+            try:
+                seed = int(seed_input)
+                if seed < 0:
+                    gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
+                    seed = np.random.randint(0, 2**31)
+            except ValueError:
+                gr.Warning("Invalid seed value. Using random seed instead.")
+                seed = np.random.randint(0, 2**31)
+
         speech_type_names_list = args[:max_speech_types]
         speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
         speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
@@ -569,12 +633,12 @@ with gr.Blocks() as app_multistyle:
                 ref_audio = speech_types[current_style]["audio"]
             except KeyError:
                 gr.Warning(f"Please provide reference audio for type {current_style}.")
-                return [None] + [speech_types[style]["ref_text"] for style in speech_types]
+                return [None] + [speech_types[style]["ref_text"] for style in speech_types] + [str(seed)]
             ref_text = speech_types[current_style].get("ref_text", "")

             # Generate speech for this segment
             audio_out, _, ref_text_out = infer(
-                ref_audio, ref_text, None, text, None, tts_model_choice, remove_silence, 0, show_info=print
+                ref_audio, ref_text, None, text, None, tts_model_choice, remove_silence, seed, 0, show_info=print
             )  # show_info=print no pull to top when generating
             sr, audio_data = audio_out

@@ -584,20 +648,20 @@ with gr.Blocks() as app_multistyle:
         # Concatenate all audio segments
         if generated_audio_segments:
             final_audio_data = np.concatenate(generated_audio_segments)
-            return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
+            return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types] + [str(seed)]
         else:
             gr.Warning("No audio generated.")
-            return [None] + [speech_types[style]["ref_text"] for style in speech_types]
+            return [None] + [speech_types[style]["ref_text"] for style in speech_types] + [str(seed)]

     generate_multistyle_btn.click(
         generate_multistyle_speech,
-        inputs=[gen_text_input_multistyle, gen_text_file_multistyle]
+        inputs=[gen_text_input_multistyle, gen_text_file_multistyle, randomize_seed_multistyle, seed_input_multistyle]
         + speech_type_names
         + speech_type_audios
         + speech_type_ref_texts
         + speech_type_ref_text_files
         + [remove_silence_multistyle],
-        outputs=[audio_output_multistyle] + speech_type_ref_texts,
+        outputs=[audio_output_multistyle] + speech_type_ref_texts + [seed_input_multistyle],
     )

     # Validation function to disable Generate button if speech types are missing
@@ -722,6 +786,18 @@ Have a conversation with an AI using your reference voice!
                 value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
                 lines=2,
             )
+            with gr.Row():
+                randomize_seed_chat = gr.Checkbox(
+                    label="Randomize Seed",
+                    value=True,
+                    info="Check to use a random seed for each generation. Uncheck to use the seed specified below.",
+                )
+                seed_input_chat = gr.Textbox(
+                    label="Seed",
+                    value="0",
+                    placeholder="Enter a seed value",
+                    scale=1,
+                )

     chatbot_interface = gr.Chatbot(label="Conversation", type="messages")

@@ -778,14 +854,27 @@ Have a conversation with an AI using your reference voice!
         return history, conv_state, "", None

     @gpu_decorator
-    def generate_audio_response(history, ref_audio, ref_text, ref_text_file, remove_silence):
+    def generate_audio_response(history, ref_audio, ref_text, ref_text_file, remove_silence, randomize_seed, seed_input):
         """Generate TTS audio for AI response"""
         if not history or not ref_audio:
-            return None, ref_text
+            return None, ref_text, seed_input

         last_user_message, last_ai_response = history[-1]
         if not last_ai_response:
-            return None, ref_text
+            return None, ref_text, seed_input

+        # Determine the seed to use
+        if randomize_seed:
+            seed = np.random.randint(0, 2**31)
+        else:
+            try:
+                seed = int(seed_input)
+                if seed < 0:
+                    gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
+                    seed = np.random.randint(0, 2**31)
+            except ValueError:
+                gr.Warning("Invalid seed value. Using random seed instead.")
+                seed = np.random.randint(0, 2**31)
+
         # Use text from file if provided, otherwise use direct text input
         ref_text = read_text_file(ref_text_file) or ref_text
@@ -798,11 +887,12 @@ Have a conversation with an AI using your reference voice!
             None,
             tts_model_choice,
             remove_silence,
+            seed=seed,
             cross_fade_duration=0.15,
             speed=1.0,
             show_info=print,  # show_info=print no pull to top when generating
         )
-        return audio_result, ref_text_out
+        return audio_result, ref_text_out, str(seed)

     def clear_conversation():
         """Reset the conversation"""
@@ -843,8 +933,8 @@ Have a conversation with an AI using your reference voice!
         outputs=[chatbot_interface, conversation_state, text_input_chat, text_file_chat],
     ).then(
         generate_audio_response,
-        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat],
-        outputs=[audio_output_chat, ref_text_chat],
+        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat],
+        outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
     ).then(
         lambda: None,
         None,
@@ -858,8 +948,8 @@ Have a conversation with an AI using your reference voice!
         outputs=[chatbot_interface, conversation_state, text_input_chat, text_file_chat],
     ).then(
         generate_audio_response,
-        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat],
-        outputs=[audio_output_chat, ref_text_chat],
+        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat],
+        outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
     )

     # Handle send button
@@ -869,8 +959,8 @@ Have a conversation with an AI using your reference voice!
         outputs=[chatbot_interface, conversation_state, text_input_chat, text_file_chat],
     ).then(
         generate_audio_response,
-        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat],
-        outputs=[audio_output_chat, ref_text_chat],
+        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat],
+        outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
     )

     # Handle clear button