Merge streaming and non-streaming inference into one infer_batch_process function

SWivid
2025-02-21 21:41:19 +08:00
parent 7ee55d773c
commit c3d415e47a
3 changed files with 78 additions and 172 deletions

View File

@@ -144,7 +144,14 @@ python src/f5_tts/socket_server.py
<details>
<summary>Then create client to communicate</summary>
```bash
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
```
```python
# Create the socket_client.py
import socket
import asyncio
import pyaudio
@@ -165,7 +172,6 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
async def play_audio_stream():
nonlocal first_chunk_time
buffer = b""
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
@@ -204,7 +210,7 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
if __name__ == "__main__":
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
asyncio.run(listen_to_F5TTS(text_to_send))
```

View File

@@ -390,22 +390,24 @@ def infer_process(
print("\n")
show_info(f"Generating audio in {len(gen_text_batches)} batches...")
return infer_batch_process(
(audio, sr),
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type=mel_spec_type,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
return next(
infer_batch_process(
(audio, sr),
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type=mel_spec_type,
progress=progress,
target_rms=target_rms,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
)
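Since the merged `infer_batch_process` (below) is now a generator, the non-streaming path yields exactly one result, which is why `infer_process` wraps the call in `next(...)`. A minimal sketch of that calling pattern, with argument values elided and names taken from this diff:

```python
# Non-streaming use: the generator yields a single
# (final_wave, sample_rate, combined_spectrogram) tuple, so next() is
# equivalent to the old direct return value of infer_batch_process.
gen = infer_batch_process((audio, sr), ref_text, gen_text_batches, model_obj, vocoder)
final_wave, sample_rate, spectrogram = next(gen)
```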
@@ -428,6 +430,8 @@ def infer_batch_process(
speed=1,
fix_duration=None,
device=None,
streaming=False,
chunk_size=2048,
):
audio, sr = ref_audio
if audio.shape[0] > 1:
@@ -446,7 +450,12 @@ def infer_batch_process(
if len(ref_text[-1].encode("utf-8")) == 1:
ref_text = ref_text + " "
for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
def process_batch(gen_text):
local_speed = speed
if len(gen_text.encode("utf-8")) < 10:
local_speed = 0.3
# Prepare the text
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
@@ -458,7 +467,7 @@ def infer_batch_process(
# Calculate duration
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
# inference
with torch.inference_mode():
@@ -484,191 +493,69 @@ def infer_batch_process(
# wav -> numpy
generated_wave = generated_wave.squeeze().cpu().numpy()
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec[0].cpu().numpy())
# Combine all generated waves with cross-fading
if cross_fade_duration <= 0:
# Simply concatenate
final_wave = np.concatenate(generated_waves)
else:
final_wave = generated_waves[0]
for i in range(1, len(generated_waves)):
prev_wave = final_wave
next_wave = generated_waves[i]
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
if cross_fade_samples <= 0:
# No overlap possible, concatenate
final_wave = np.concatenate([prev_wave, next_wave])
continue
# Overlapping parts
prev_overlap = prev_wave[-cross_fade_samples:]
next_overlap = next_wave[:cross_fade_samples]
# Fade out and fade in
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
# Cross-faded overlap
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
# Combine
new_wave = np.concatenate(
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
)
final_wave = new_wave
# Create a combined spectrogram
combined_spectrogram = np.concatenate(spectrograms, axis=1)
return final_wave, target_sample_rate, combined_spectrogram
# infer batch process for stream mode
def infer_batch_process_stream(
ref_audio,
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type="vocos",
progress=None,
target_rms=0.1,
cross_fade_duration=0.15,
nfe_step=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
speed=1,
fix_duration=None,
device=None,
streaming=False,
chunk_size=2048,
):
audio, sr = ref_audio
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
audio = audio * target_rms / rms
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
audio = audio.to(device)
if len(ref_text[-1].encode("utf-8")) == 1:
ref_text = ref_text + " "
generated_waves = []
spectrograms = []
def process_batch(i, gen_text):
print(f"Generating audio for batch {i + 1}/{len(gen_text_batches)}: {gen_text}")
local_speed = speed
if len(gen_text) < 10:
local_speed = 0.3
text_list = [ref_text + gen_text]
final_text_list = convert_char_to_pinyin(text_list)
ref_audio_len = audio.shape[-1] // hop_length
if fix_duration is not None:
duration = int(fix_duration * target_sample_rate / hop_length)
else:
ref_text_len = len(ref_text.encode("utf-8"))
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
with torch.inference_mode():
generated, _ = model_obj.sample(
cond=audio,
text=final_text_list,
duration=duration,
steps=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
print(f"Generated mel spectrogram shape: {generated_mel_spec.shape}")
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(generated_mel_spec)
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(generated_mel_spec)
print(f"Generated wave shape before RMS adjustment: {generated_wave.shape}")
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
print(f"Generated wave shape after RMS adjustment: {generated_wave.shape}")
generated_wave = generated_wave.squeeze().cpu().numpy()
if streaming:
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
return generated_wave, generated_mel_spec[0].cpu().numpy()
else:
yield generated_wave, generated_mel_spec[0].cpu().numpy()
if streaming:
for i, gen_text in enumerate(progress.tqdm(gen_text_batches) if progress else gen_text_batches):
for chunk in process_batch(i, gen_text):
for gen_text in progress.tqdm(gen_text_batches) if progress else gen_text_batches:
for chunk in process_batch(gen_text):
yield chunk
else:
with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)]
futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
for future in progress.tqdm(futures) if progress else futures:
result = future.result()
if result:
generated_wave, generated_mel_spec = result
generated_wave, generated_mel_spec = next(result)
generated_waves.append(generated_wave)
spectrograms.append(generated_mel_spec)
if generated_waves:
if cross_fade_duration <= 0:
# Simply concatenate
final_wave = np.concatenate(generated_waves)
else:
# Combine all generated waves with cross-fading
final_wave = generated_waves[0]
for i in range(1, len(generated_waves)):
prev_wave = final_wave
next_wave = generated_waves[i]
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
if cross_fade_samples <= 0:
# No overlap possible, concatenate
final_wave = np.concatenate([prev_wave, next_wave])
continue
# Overlapping parts
prev_overlap = prev_wave[-cross_fade_samples:]
next_overlap = next_wave[:cross_fade_samples]
# Fade out and fade in
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
# Cross-faded overlap
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
# Combine
new_wave = np.concatenate(
[prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
)
final_wave = new_wave
# Create a combined spectrogram
combined_spectrogram = np.concatenate(spectrograms, axis=1)
yield final_wave, target_sample_rate, combined_spectrogram
else:
yield None, target_sample_rate, None
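With `streaming=True`, the same generator instead yields `(chunk, sample_rate)` pairs of float32 sample arrays, `chunk_size` samples at a time, which is how the socket server in the next file consumes it. A hedged usage sketch (model, vocoder, and reference-audio setup omitted; `send_to_client` is a hypothetical sink, not part of this codebase):

```python
# Streaming use: iterate audio chunks as they are produced instead of
# waiting for the full cross-faded waveform.
for chunk, sample_rate in infer_batch_process(
    (audio, sr),
    ref_text,
    text_batches,
    model,
    vocoder,
    progress=None,
    device=device,
    streaming=True,
    chunk_size=2048,
):
    send_to_client(chunk)  # hypothetical: write to a socket, audio device, etc.
```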

View File

@@ -1,20 +1,31 @@
import argparse
import gc
import logging
import numpy as np
import queue
import socket
import struct
import threading
import traceback
import wave
from importlib.resources import files
import torch
import torchaudio
import logging
import wave
import numpy as np
import argparse
import traceback
import gc
import threading
import queue
from nltk.tokenize import sent_tokenize
from infer.utils_infer import preprocess_ref_audio_text, load_vocoder, load_model, infer_batch_process_stream
from model.backbones.dit import DiT
from huggingface_hub import hf_hub_download
from importlib.resources import files
import nltk
from nltk.tokenize import sent_tokenize
from f5_tts.model.backbones.dit import DiT
from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_batch_process,
)
nltk.download("punkt_tab")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -103,12 +114,13 @@ class TTSStreamingProcessor:
def _warm_up(self):
logger.info("Warming up the model...")
gen_text = "Warm-up text for the model."
for _ in infer_batch_process_stream(
for _ in infer_batch_process(
(self.audio, self.sr),
self.ref_text,
[gen_text],
self.model,
self.vocoder,
progress=None,
device=self.device,
streaming=True,
):
@@ -118,12 +130,13 @@ class TTSStreamingProcessor:
def generate_stream(self, text, conn):
text_batches = sent_tokenize(text)
audio_stream = infer_batch_process_stream(
audio_stream = infer_batch_process(
(self.audio, self.sr),
self.ref_text,
text_batches,
self.model,
self.vocoder,
progress=None,
device=self.device,
streaming=True,
chunk_size=2048,