Several fixes for utils_infer.py; separate streaming and non-streaming functions and add parallelism back

This commit is contained in:
SWivid
2026-03-24 20:03:01 +08:00
parent 2414e3d492
commit 1a63dda3df

View File

@@ -2,6 +2,7 @@
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
import os import os
import sys import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
@@ -86,6 +87,8 @@ def chunk_text(text, max_chars=135):
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text) sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
for sentence in sentences: for sentence in sentences:
if not sentence:
continue
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars: if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
else: else:
@@ -279,12 +282,12 @@ def remove_silence_edges(audio, silence_threshold=-42):
audio = audio[non_silent_start_idx:] audio = audio[non_silent_start_idx:]
# Remove silence from the end # Remove silence from the end
non_silent_end_duration = audio.duration_seconds reversed_audio = audio.reverse()
for ms in reversed(audio): non_silent_end_idx = silence.detect_leading_silence(reversed_audio, silence_threshold=silence_threshold)
if ms.dBFS > silence_threshold: if non_silent_end_idx > 0:
break trimmed_audio = audio[: len(audio) - non_silent_end_idx]
non_silent_end_duration -= 0.001 else:
trimmed_audio = audio[: int(non_silent_end_duration * 1000)] trimmed_audio = audio
return trimmed_audio return trimmed_audio
@@ -400,11 +403,16 @@ def infer_process(
audio, sr = torchaudio.load(ref_audio) audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed) max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
gen_text_batches = chunk_text(gen_text, max_chars=max_chars) gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches): for i, gen_text_i in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text) print(f"gen_text {i}", gen_text_i)
print("\n") print("\n")
show_info(f"Generating audio in {len(gen_text_batches)} batches...") show_info(f"Generating audio in {len(gen_text_batches)} batches...")
if not gen_text_batches:
show_info("No text batches to generate.")
return None, target_sample_rate, None
return next( return next(
infer_batch_process( infer_batch_process(
(audio, sr), (audio, sr),
@@ -466,7 +474,7 @@ def infer_batch_process(
if len(ref_text[-1].encode("utf-8")) == 1: if len(ref_text[-1].encode("utf-8")) == 1:
ref_text = ref_text + " " ref_text = ref_text + " "
def process_batch(gen_text): def _infer_basic(gen_text):
local_speed = speed local_speed = speed
if len(gen_text.encode("utf-8")) < 10: if len(gen_text.encode("utf-8")) < 10:
local_speed = 0.3 local_speed = 0.3
@@ -509,23 +517,34 @@ def infer_batch_process(
# wav -> numpy # wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy() generated_wave = generated_wave.squeeze().cpu().numpy()
if streaming: return generated_wave, generated
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate def infer_single_process(gen_text):
else: generated_wave, generated = _infer_basic(gen_text)
generated_cpu = generated[0].cpu().numpy() generated_cpu = generated[0].cpu().numpy()
del generated del generated
yield generated_wave, generated_cpu return generated_wave, generated_cpu
def infer_single_process_streaming(gen_text):
# for src/f5_tts/socket_server.py
generated_wave, generated = _infer_basic(gen_text)
del generated
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
if streaming: if streaming:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
for chunk in process_batch(gen_text): for chunk in infer_single_process_streaming(gen_text):
yield chunk yield chunk
else: else:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches: with ThreadPoolExecutor() as executor:
generated_wave, generated_mel_spec = next(process_batch(gen_text)) futures = [executor.submit(infer_single_process, gen_text) for gen_text in gen_text_batches]
generated_waves.append(generated_wave) for future in progress.tqdm(futures) if progress is not None else futures:
spectrograms.append(generated_mel_spec) result = future.result()
if result:
generated_wave, generated_mel_spec = result
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec)
if generated_waves: if generated_waves:
if cross_fade_duration <= 0: if cross_fade_duration <= 0: