mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-04-28 08:43:06 -07:00
Several fixes for utils_infer.py: separate the streaming and non-streaming functions, and add back parallelism
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||||
@@ -86,6 +87,8 @@ def chunk_text(text, max_chars=135):
|
|||||||
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
|
||||||
|
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
|
if not sentence:
|
||||||
|
continue
|
||||||
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
|
||||||
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
|
||||||
else:
|
else:
|
||||||
@@ -279,12 +282,12 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
|||||||
audio = audio[non_silent_start_idx:]
|
audio = audio[non_silent_start_idx:]
|
||||||
|
|
||||||
# Remove silence from the end
|
# Remove silence from the end
|
||||||
non_silent_end_duration = audio.duration_seconds
|
reversed_audio = audio.reverse()
|
||||||
for ms in reversed(audio):
|
non_silent_end_idx = silence.detect_leading_silence(reversed_audio, silence_threshold=silence_threshold)
|
||||||
if ms.dBFS > silence_threshold:
|
if non_silent_end_idx > 0:
|
||||||
break
|
trimmed_audio = audio[: len(audio) - non_silent_end_idx]
|
||||||
non_silent_end_duration -= 0.001
|
else:
|
||||||
trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
|
trimmed_audio = audio
|
||||||
|
|
||||||
return trimmed_audio
|
return trimmed_audio
|
||||||
|
|
||||||
@@ -400,11 +403,16 @@ def infer_process(
|
|||||||
audio, sr = torchaudio.load(ref_audio)
|
audio, sr = torchaudio.load(ref_audio)
|
||||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
||||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||||
for i, gen_text in enumerate(gen_text_batches):
|
for i, gen_text_i in enumerate(gen_text_batches):
|
||||||
print(f"gen_text {i}", gen_text)
|
print(f"gen_text {i}", gen_text_i)
|
||||||
print("\n")
|
print("\n")
|
||||||
|
|
||||||
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
|
||||||
|
|
||||||
|
if not gen_text_batches:
|
||||||
|
show_info("No text batches to generate.")
|
||||||
|
return None, target_sample_rate, None
|
||||||
|
|
||||||
return next(
|
return next(
|
||||||
infer_batch_process(
|
infer_batch_process(
|
||||||
(audio, sr),
|
(audio, sr),
|
||||||
@@ -466,7 +474,7 @@ def infer_batch_process(
|
|||||||
if len(ref_text[-1].encode("utf-8")) == 1:
|
if len(ref_text[-1].encode("utf-8")) == 1:
|
||||||
ref_text = ref_text + " "
|
ref_text = ref_text + " "
|
||||||
|
|
||||||
def process_batch(gen_text):
|
def _infer_basic(gen_text):
|
||||||
local_speed = speed
|
local_speed = speed
|
||||||
if len(gen_text.encode("utf-8")) < 10:
|
if len(gen_text.encode("utf-8")) < 10:
|
||||||
local_speed = 0.3
|
local_speed = 0.3
|
||||||
@@ -509,23 +517,34 @@ def infer_batch_process(
|
|||||||
# wav -> numpy
|
# wav -> numpy
|
||||||
generated_wave = generated_wave.squeeze().cpu().numpy()
|
generated_wave = generated_wave.squeeze().cpu().numpy()
|
||||||
|
|
||||||
if streaming:
|
return generated_wave, generated
|
||||||
for j in range(0, len(generated_wave), chunk_size):
|
|
||||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
def infer_single_process(gen_text):
|
||||||
else:
|
generated_wave, generated = _infer_basic(gen_text)
|
||||||
generated_cpu = generated[0].cpu().numpy()
|
generated_cpu = generated[0].cpu().numpy()
|
||||||
del generated
|
del generated
|
||||||
yield generated_wave, generated_cpu
|
return generated_wave, generated_cpu
|
||||||
|
|
||||||
|
def infer_single_process_streaming(gen_text):
|
||||||
|
# for src/f5_tts/socket_server.py
|
||||||
|
generated_wave, generated = _infer_basic(gen_text)
|
||||||
|
del generated
|
||||||
|
for j in range(0, len(generated_wave), chunk_size):
|
||||||
|
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
||||||
for chunk in process_batch(gen_text):
|
for chunk in infer_single_process_streaming(gen_text):
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
|
with ThreadPoolExecutor() as executor:
|
||||||
generated_wave, generated_mel_spec = next(process_batch(gen_text))
|
futures = [executor.submit(infer_single_process, gen_text) for gen_text in gen_text_batches]
|
||||||
generated_waves.append(generated_wave)
|
for future in progress.tqdm(futures) if progress is not None else futures:
|
||||||
spectrograms.append(generated_mel_spec)
|
result = future.result()
|
||||||
|
if result:
|
||||||
|
generated_wave, generated_mel_spec = result
|
||||||
|
generated_waves.append(generated_wave)
|
||||||
|
spectrograms.append(generated_mel_spec)
|
||||||
|
|
||||||
if generated_waves:
|
if generated_waves:
|
||||||
if cross_fade_duration <= 0:
|
if cross_fade_duration <= 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user