Update infer_gradio.py

Added "randomize seed" checkmark and option to specify seed showing last seed used and can manually enter the desired seed number.
Author: petermg
Date: 2025-05-03 11:38:50 -07:00 (committed by GitHub)
Parent: ba1bf74215
Commit: 95976041f2


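For orientation before the diff (this sketch is not part of the commit): each tab gets the same pattern, where a "Randomize Seed" checkbox and a Seed textbox feed the handler, and the seed that was actually used is sent back as an extra output so the textbox shows it afterwards. A minimal standalone Gradio example of that round trip, assuming only gradio and numpy:

    import gradio as gr
    import numpy as np

    with gr.Blocks() as demo:
        randomize = gr.Checkbox(label="Randomize Seed", value=True)
        seed_box = gr.Textbox(label="Seed", value="0")
        run_btn = gr.Button("Run")

        def run(randomize_seed, seed_text):
            # Pick a fresh seed, or reuse the one typed into the textbox.
            seed = int(np.random.randint(0, 2**31)) if randomize_seed else int(seed_text)
            # ... seeded generation would happen here ...
            return str(seed)  # written back so the UI shows the seed that was used

        run_btn.click(run, inputs=[randomize, seed_box], outputs=[seed_box])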
@@ -129,6 +129,7 @@ def infer(
gen_text_file,
model,
remove_silence,
seed,
cross_fade_duration=0.15,
nfe_step=32,
speed=1,
@@ -146,6 +147,12 @@ def infer(
gr.Warning("Please enter text to generate or upload a text file.")
return gr.update(), gr.update(), ref_text
# Set random seed for reproducibility
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
if model == DEFAULT_TTS_MODEL:
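The block added above seeds the RNG sources the synthesis path may draw from. The same idea as a standalone helper (hypothetical sketch, not code from the commit):

    import numpy as np
    import torch

    def seed_everything(seed: int) -> None:
        # Seed the generators that can influence generation.
        torch.manual_seed(seed)                   # PyTorch CPU generator
        np.random.seed(seed)                      # NumPy global generator
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)      # every visible CUDA device

With identical inputs and the same seed, repeated runs should then match, apart from any nondeterministic CUDA kernels.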
@@ -206,7 +213,7 @@ with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
with gr.Column(scale=2):
with gr.Column(scale=1):
gen_text_file = gr.File(label="Upload Text File to Generate (.txt)", file_types=[".txt"])
generate_btn = gr.Button("Synthesize", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
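In Gradio, a Column's scale sets its width relative to its siblings in the same Row, so dropping the upload column from scale=2 to scale=1 roughly halves its relative width. A generic illustration (hypothetical layout, not lines from this file):

    import gradio as gr

    with gr.Blocks() as layout_demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Textbox(label="One share of the row width")
            with gr.Column(scale=2):
                gr.Textbox(label="Two shares, roughly twice as wide")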
@@ -223,6 +230,18 @@ with gr.Blocks() as app_tts:
info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
value=False,
)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Check to use a random seed for each generation. Uncheck to use the seed specified below.",
)
seed_input = gr.Textbox(
label="Seed",
value="0",
placeholder="Enter a seed value",
scale=1,
)
speed_slider = gr.Slider(
label="Speed",
minimum=0.3,
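The seed arrives as free-form text from a gr.Textbox, which is why the handlers below parse and validate the string. A hypothetical alternative, not what this commit does, would be gr.Number with precision=0, which keeps the input an integer at the widget level:

    import gradio as gr

    # Hypothetical alternative widget; the commit uses a Textbox instead.
    seed_input = gr.Number(label="Seed", value=0, precision=0)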
@@ -271,10 +290,25 @@ with gr.Blocks() as app_tts:
gen_text_input,
gen_text_file,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
):
# Determine the seed to use
if randomize_seed:
seed = np.random.randint(0, 2**31)
else:
try:
seed = int(seed_input)
if seed < 0:
gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
seed = np.random.randint(0, 2**31)
except ValueError:
gr.Warning("Invalid seed value. Using random seed instead.")
seed = np.random.randint(0, 2**31)
audio_out, spectrogram_path, ref_text_out = infer(
ref_audio_input,
ref_text_input,
@@ -283,11 +317,12 @@ with gr.Blocks() as app_tts:
gen_text_file,
tts_model_choice,
remove_silence,
seed=seed,
cross_fade_duration=cross_fade_duration_slider,
nfe_step=nfe_slider,
speed=speed_slider,
)
return audio_out, spectrogram_path, ref_text_out
return audio_out, spectrogram_path, ref_text_out, str(seed)
gen_text_file.change(
update_gen_text_from_file,
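The randomize-or-parse block above is repeated almost verbatim in the multi-style and chat handlers later in this diff. If one wanted to consolidate it, a single helper could carry the logic (hypothetical sketch, not part of the commit):

    import numpy as np
    import gradio as gr

    def resolve_seed(randomize_seed: bool, seed_input: str) -> int:
        # Hypothetical helper mirroring the repeated logic in this diff.
        if randomize_seed:
            return int(np.random.randint(0, 2**31))
        try:
            seed = int(seed_input)
        except ValueError:
            gr.Warning("Invalid seed value. Using random seed instead.")
            return int(np.random.randint(0, 2**31))
        if seed < 0:
            gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
            return int(np.random.randint(0, 2**31))
        return seed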
@@ -310,11 +345,13 @@ with gr.Blocks() as app_tts:
gen_text_input,
gen_text_file,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
],
outputs=[audio_output, spectrogram_output, ref_text_input],
outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
)
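Gradio maps the wrapper's return values onto the outputs list by position, which is how the seed ends up displayed in the Seed textbox after each run. Schematically (annotation only, not code from the commit):

    # return value         -> output component
    # audio_out            -> audio_output
    # spectrogram_path     -> spectrogram_output
    # ref_text_out         -> ref_text_input
    # str(seed)            -> seed_input   (textbox now shows the seed actually used)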
@@ -501,6 +538,18 @@ with gr.Blocks() as app_multistyle:
label="Remove Silences",
value=True,
)
with gr.Row():
randomize_seed_multistyle = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Check to use a random seed for each generation. Uncheck to use the seed specified below.",
)
seed_input_multistyle = gr.Textbox(
label="Seed",
value="0",
placeholder="Enter a seed value",
scale=1,
)
# Generate button
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -524,8 +573,23 @@ with gr.Blocks() as app_multistyle:
def generate_multistyle_speech(
gen_text,
gen_text_file,
randomize_seed,
seed_input,
*args,
):
# Determine the seed to use
if randomize_seed:
seed = np.random.randint(0, 2**31)
else:
try:
seed = int(seed_input)
if seed < 0:
gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
seed = np.random.randint(0, 2**31)
except ValueError:
gr.Warning("Invalid seed value. Using random seed instead.")
seed = np.random.randint(0, 2**31)
speech_type_names_list = args[:max_speech_types]
speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
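The multi-style handler receives all dynamic components as one flat *args tuple and slices it back into groups of max_speech_types. A self-contained illustration of the slicing (all values are made up):

    # Hypothetical values purely to illustrate the slicing pattern.
    max_speech_types = 3
    args = (
        "Regular", "Happy", "Sad",            # speech type names
        "reg.wav", "happy.wav", "sad.wav",    # reference audio paths
        "ref A", "ref B", "ref C",            # reference texts
    )
    speech_type_names_list = args[:max_speech_types]
    speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
    speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]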
@@ -569,12 +633,12 @@ with gr.Blocks() as app_multistyle:
ref_audio = speech_types[current_style]["audio"]
except KeyError:
gr.Warning(f"Please provide reference audio for type {current_style}.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
return [None] + [speech_types[style]["ref_text"] for style in speech_types] + [str(seed)]
ref_text = speech_types[current_style].get("ref_text", "")
# Generate speech for this segment
audio_out, _, ref_text_out = infer(
ref_audio, ref_text, None, text, None, tts_model_choice, remove_silence, 0, show_info=print
ref_audio, ref_text, None, text, None, tts_model_choice, remove_silence, seed, 0, show_info=print
) # show_info=print no pull to top when generating
sr, audio_data = audio_out
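In the call above, seed is inserted positionally just before the literal 0, which, per the updated infer signature at the top of this diff, binds to cross_fade_duration. Spelled out with keyword arguments the call would read roughly as follows (hypothetical rewrite for readability only, reusing the names from the diff):

    audio_out, _, ref_text_out = infer(
        ref_audio, ref_text, None, text, None,
        tts_model_choice, remove_silence,
        seed=seed,
        cross_fade_duration=0,   # the positional 0 in the original call
        show_info=print,
    )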
@@ -584,20 +648,20 @@ with gr.Blocks() as app_multistyle:
# Concatenate all audio segments
if generated_audio_segments:
final_audio_data = np.concatenate(generated_audio_segments)
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types] + [str(seed)]
else:
gr.Warning("No audio generated.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
return [None] + [speech_types[style]["ref_text"] for style in speech_types] + [str(seed)]
generate_multistyle_btn.click(
generate_multistyle_speech,
inputs=[gen_text_input_multistyle, gen_text_file_multistyle]
inputs=[gen_text_input_multistyle, gen_text_file_multistyle, randomize_seed_multistyle, seed_input_multistyle]
+ speech_type_names
+ speech_type_audios
+ speech_type_ref_texts
+ speech_type_ref_text_files
+ [remove_silence_multistyle],
outputs=[audio_output_multistyle] + speech_type_ref_texts,
outputs=[audio_output_multistyle] + speech_type_ref_texts + [seed_input_multistyle],
)
# Validation function to disable Generate button if speech types are missing
@@ -722,6 +786,18 @@ Have a conversation with an AI using your reference voice!
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)
with gr.Row():
randomize_seed_chat = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Check to use a random seed for each generation. Uncheck to use the seed specified below.",
)
seed_input_chat = gr.Textbox(
label="Seed",
value="0",
placeholder="Enter a seed value",
scale=1,
)
chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
@@ -778,14 +854,27 @@ Have a conversation with an AI using your reference voice!
return history, conv_state, "", None
@gpu_decorator
def generate_audio_response(history, ref_audio, ref_text, ref_text_file, remove_silence):
def generate_audio_response(history, ref_audio, ref_text, ref_text_file, remove_silence, randomize_seed, seed_input):
"""Generate TTS audio for AI response"""
if not history or not ref_audio:
return None, ref_text
return None, ref_text, seed_input
last_user_message, last_ai_response = history[-1]
if not last_ai_response:
return None, ref_text
return None, ref_text, seed_input
# Determine the seed to use
if randomize_seed:
seed = np.random.randint(0, 2**31)
else:
try:
seed = int(seed_input)
if seed < 0:
gr.Warning("Seed must be a non-negative integer. Using random seed instead.")
seed = np.random.randint(0, 2**31)
except ValueError:
gr.Warning("Invalid seed value. Using random seed instead.")
seed = np.random.randint(0, 2**31)
# Use text from file if provided, otherwise use direct text input
ref_text = read_text_file(ref_text_file) or ref_text
@@ -798,11 +887,12 @@ Have a conversation with an AI using your reference voice!
None,
tts_model_choice,
remove_silence,
seed=seed,
cross_fade_duration=0.15,
speed=1.0,
show_info=print, # show_info=print no pull to top when generating
)
return audio_result, ref_text_out
return audio_result, ref_text_out, str(seed)
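All three return paths in generate_audio_response now carry three values so they line up with the extended outputs list; the early exits pass the raw seed_input string through unchanged, so the textbox keeps its value when nothing is generated. Schematically (annotation only):

    # no history / no reference audio -> None, ref_text, seed_input
    # no AI response yet              -> None, ref_text, seed_input
    # success                         -> audio_result, ref_text_out, str(seed)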
def clear_conversation():
"""Reset the conversation"""
@@ -843,8 +933,8 @@ Have a conversation with an AI using your reference voice!
outputs=[chatbot_interface, conversation_state, text_input_chat, text_file_chat],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat],
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
).then(
lambda: None,
None,
@@ -858,8 +948,8 @@ Have a conversation with an AI using your reference voice!
outputs=[chatbot_interface, conversation_state, text_input_chat, text_file_chat],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat],
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
)
# Handle send button
@@ -869,8 +959,8 @@ Have a conversation with an AI using your reference voice!
outputs=[chatbot_interface, conversation_state, text_input_chat, text_file_chat],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat, remove_silence_chat, randomize_seed_chat, seed_input_chat],
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
)
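The same generate_audio_response wiring is attached three times to different triggers in this tab. A hypothetical refactor, not part of the commit, could chain it onto any prior event dependency with a small helper:

    def chain_audio_response(dep):
        # `dep` is the event dependency returned by the preceding .click()/.submit()/.then().
        return dep.then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, ref_text_file_chat,
                    remove_silence_chat, randomize_seed_chat, seed_input_chat],
            outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
        )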
# Handle clear button