Several fixes for utils_infer.py; separate streaming and non-streaming functions and add back parallelism

This commit is contained in:
SWivid
2026-03-24 20:03:01 +08:00
parent 2414e3d492
commit 1a63dda3df

View File

@@ -2,6 +2,7 @@
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
@@ -86,6 +87,8 @@ def chunk_text(text, max_chars=135):
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
for sentence in sentences:
if not sentence:
continue
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
else:
@@ -279,12 +282,12 @@ def remove_silence_edges(audio, silence_threshold=-42):
audio = audio[non_silent_start_idx:]
# Remove silence from the end
non_silent_end_duration = audio.duration_seconds
for ms in reversed(audio):
if ms.dBFS > silence_threshold:
break
non_silent_end_duration -= 0.001
trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
reversed_audio = audio.reverse()
non_silent_end_idx = silence.detect_leading_silence(reversed_audio, silence_threshold=silence_threshold)
if non_silent_end_idx > 0:
trimmed_audio = audio[: len(audio) - non_silent_end_idx]
else:
trimmed_audio = audio
return trimmed_audio
@@ -400,11 +403,16 @@ def infer_process(
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)
for i, gen_text_i in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text_i)
print("\n")
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
if not gen_text_batches:
show_info("No text batches to generate.")
return None, target_sample_rate, None
return next(
infer_batch_process(
(audio, sr),
@@ -466,7 +474,7 @@ def infer_batch_process(
if len(ref_text[-1].encode("utf-8")) == 1:
ref_text = ref_text + " "
def process_batch(gen_text):
def _infer_basic(gen_text):
local_speed = speed
if len(gen_text.encode("utf-8")) < 10:
local_speed = 0.3
@@ -509,23 +517,34 @@ def infer_batch_process(
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
if streaming:
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
else:
generated_cpu = generated[0].cpu().numpy()
del generated
yield generated_wave, generated_cpu
return generated_wave, generated
# Non-streaming worker: synthesize one text batch and return plain results.
# Submitted per-batch to a ThreadPoolExecutor by the enclosing function.
def infer_single_process(gen_text):
# Delegate the actual synthesis to the shared `_infer_basic` helper
# (returns the waveform plus a generated tensor from the model).
generated_wave, generated = _infer_basic(gen_text)
# Copy the first element of the generated tensor to CPU numpy, then drop
# the tensor reference so its (presumably GPU) memory can be freed early
# — NOTE(review): assumes `generated[0]` is the mel spectrogram; confirm
# against `_infer_basic`.
generated_cpu = generated[0].cpu().numpy()
del generated
return generated_wave, generated_cpu
# Streaming worker: synthesize one text batch, then yield the waveform in
# fixed-size chunks so callers can start playback before synthesis of the
# whole batch output has been consumed.
def infer_single_process_streaming(gen_text):
# for src/f5_tts/socket_server.py
generated_wave, generated = _infer_basic(gen_text)
# The secondary tensor is unused on the streaming path; release it
# immediately rather than holding it for the generator's lifetime.
del generated
# Slice the waveform into `chunk_size`-sample windows (closure variable
# from the enclosing scope); each yield pairs a chunk with the sample rate.
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
if streaming:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
for chunk in process_batch(gen_text):
for chunk in infer_single_process_streaming(gen_text):
yield chunk
else:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
generated_wave, generated_mel_spec = next(process_batch(gen_text))
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec)
with ThreadPoolExecutor() as executor:
futures = [executor.submit(infer_single_process, gen_text) for gen_text in gen_text_batches]
for future in progress.tqdm(futures) if progress is not None else futures:
result = future.result()
if result:
generated_wave, generated_mel_spec = result
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec)
if generated_waves:
if cross_fade_duration <= 0: