Mirror of https://github.com/SWivid/F5-TTS.git, synced 2026-04-28 08:43:06 -07:00.
Several fixes for utils_infer.py: separate the streaming and non-streaming functions, and add back parallelism for batch inference.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
@@ -86,6 +87,8 @@ def chunk_text(text, max_chars=135):
|
||||
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
||||
|
||||
for sentence in sentences:
|
||||
if not sentence:
|
||||
continue
|
||||
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
||||
else:
|
||||
@@ -279,12 +282,12 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
||||
audio = audio[non_silent_start_idx:]
|
||||
|
||||
# Remove silence from the end
|
||||
non_silent_end_duration = audio.duration_seconds
|
||||
for ms in reversed(audio):
|
||||
if ms.dBFS > silence_threshold:
|
||||
break
|
||||
non_silent_end_duration -= 0.001
|
||||
trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
|
||||
reversed_audio = audio.reverse()
|
||||
non_silent_end_idx = silence.detect_leading_silence(reversed_audio, silence_threshold=silence_threshold)
|
||||
if non_silent_end_idx > 0:
|
||||
trimmed_audio = audio[: len(audio) - non_silent_end_idx]
|
||||
else:
|
||||
trimmed_audio = audio
|
||||
|
||||
return trimmed_audio
|
||||
|
||||
@@ -400,11 +403,16 @@ def infer_process(
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text)
|
||||
for i, gen_text_i in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text_i)
|
||||
print("\n")
|
||||
|
||||
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
||||
|
||||
if not gen_text_batches:
|
||||
show_info("No text batches to generate.")
|
||||
return None, target_sample_rate, None
|
||||
|
||||
return next(
|
||||
infer_batch_process(
|
||||
(audio, sr),
|
||||
@@ -466,7 +474,7 @@ def infer_batch_process(
|
||||
if len(ref_text[-1].encode("utf-8")) == 1:
|
||||
ref_text = ref_text + " "
|
||||
|
||||
def process_batch(gen_text):
|
||||
def _infer_basic(gen_text):
|
||||
local_speed = speed
|
||||
if len(gen_text.encode("utf-8")) < 10:
|
||||
local_speed = 0.3
|
||||
@@ -509,23 +517,34 @@ def infer_batch_process(
|
||||
# wav -> numpy
|
||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
||||
|
||||
if streaming:
|
||||
for j in range(0, len(generated_wave), chunk_size):
|
||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||
else:
|
||||
generated_cpu = generated[0].cpu().numpy()
|
||||
del generated
|
||||
yield generated_wave, generated_cpu
|
||||
return generated_wave, generated
|
||||
|
||||
def infer_single_process(gen_text):
|
||||
generated_wave, generated = _infer_basic(gen_text)
|
||||
generated_cpu = generated[0].cpu().numpy()
|
||||
del generated
|
||||
return generated_wave, generated_cpu
|
||||
|
||||
def infer_single_process_streaming(gen_text):
|
||||
# for src/f5_tts/socket_server.py
|
||||
generated_wave, generated = _infer_basic(gen_text)
|
||||
del generated
|
||||
for j in range(0, len(generated_wave), chunk_size):
|
||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||
|
||||
if streaming:
|
||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
||||
for chunk in process_batch(gen_text):
|
||||
for chunk in infer_single_process_streaming(gen_text):
|
||||
yield chunk
|
||||
else:
|
||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
||||
generated_wave, generated_mel_spec = next(process_batch(gen_text))
|
||||
generated_waves.append(generated_wave)
|
||||
spectrograms.append(generated_mel_spec)
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(infer_single_process, gen_text) for gen_text in gen_text_batches]
|
||||
for future in progress.tqdm(futures) if progress is not None else futures:
|
||||
result = future.result()
|
||||
if result:
|
||||
generated_wave, generated_mel_spec = result
|
||||
generated_waves.append(generated_wave)
|
||||
spectrograms.append(generated_mel_spec)
|
||||
|
||||
if generated_waves:
|
||||
if cross_fade_duration <= 0:
|
||||
|
||||
Reference in New Issue
Block a user