Mirror of https://github.com/SWivid/F5-TTS.git, synced 2026-01-14 05:57:41 -08:00

Commit: merging into one infer_batch_process function

@@ -144,7 +144,14 @@ python src/f5_tts/socket_server.py
<details>
<summary>Then create client to communicate</summary>

```bash
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
```

```python
# Create the socket_client.py
import socket
import asyncio
import pyaudio
@@ -165,7 +172,6 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):

    async def play_audio_stream():
        nonlocal first_chunk_time
        buffer = b""
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
@@ -204,7 +210,7 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):


if __name__ == "__main__":
-    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
+    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"

    asyncio.run(listen_to_F5TTS(text_to_send))
```
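The diff view truncates the client example above. For orientation, here is a minimal self-contained sketch in the same spirit; it assumes the server streams raw float32 PCM at 24 kHz on port 9998 and signals completion by closing the connection — the real socket_client.py / socket_server.py protocol (e.g. an explicit end marker) may differ:

```python
# Minimal synchronous sketch, NOT the repository's socket_client.py.
# Assumptions: raw float32 PCM at 24 kHz, server closes the socket when done.
import socket

import pyaudio


def play_from_server(text, host="localhost", port=9998):
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True)
    buf = b""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        s.sendall(text.encode("utf-8"))
        while True:
            data = s.recv(8192)
            if not data:
                break
            buf += data
            cut = len(buf) - len(buf) % 4  # play only whole 4-byte float32 frames
            stream.write(buf[:cut])
            buf = buf[cut:]
    stream.stop_stream()
    stream.close()
    p.terminate()


if __name__ == "__main__":
    play_from_server("Hello from a minimal F5-TTS socket client.")
```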
@@ -390,22 +390,24 @@ def infer_process(
            print("\n")

    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
-    return infer_batch_process(
-        (audio, sr),
-        ref_text,
-        gen_text_batches,
-        model_obj,
-        vocoder,
-        mel_spec_type=mel_spec_type,
-        progress=progress,
-        target_rms=target_rms,
-        cross_fade_duration=cross_fade_duration,
-        nfe_step=nfe_step,
-        cfg_strength=cfg_strength,
-        sway_sampling_coef=sway_sampling_coef,
-        speed=speed,
-        fix_duration=fix_duration,
-        device=device,
+    return next(
+        infer_batch_process(
+            (audio, sr),
+            ref_text,
+            gen_text_batches,
+            model_obj,
+            vocoder,
+            mel_spec_type=mel_spec_type,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=device,
+        )
+    )
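The next(...) wrapper is needed because the merged infer_batch_process now contains yield statements, and a function containing yield always returns a generator object when called — even on the streaming=False path. next() runs the generator up to its single yield and hands back the (final_wave, target_sample_rate, combined_spectrogram) tuple. A minimal illustration of the pattern, with made-up names:

```python
def produce(streaming=False):
    if streaming:
        for part in ("a", "b", "c"):
            yield part          # many small results
    else:
        yield "full result"     # one combined result, yielded instead of returned


print(next(produce()))                # -> "full result"
print(list(produce(streaming=True)))  # -> ["a", "b", "c"]
```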
@@ -428,6 +430,8 @@ def infer_batch_process(
    speed=1,
    fix_duration=None,
    device=None,
+    streaming=False,
+    chunk_size=2048,
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
@@ -446,7 +450,12 @@ def infer_batch_process(

    if len(ref_text[-1].encode("utf-8")) == 1:
        ref_text = ref_text + " "
-    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
+
+    def process_batch(gen_text):
+        local_speed = speed
+        if len(gen_text.encode("utf-8")) < 10:
+            local_speed = 0.3

        # Prepare the text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)
@@ -458,7 +467,7 @@ def infer_batch_process(
            # Calculate duration
            ref_text_len = len(ref_text.encode("utf-8"))
            gen_text_len = len(gen_text.encode("utf-8"))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)

        # inference
        with torch.inference_mode():
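The duration heuristic allots mel frames in proportion to text length: the reference's frames-per-byte rate is scaled by the generated text's byte count and divided by the speed factor, and batches shorter than 10 UTF-8 bytes are slowed to local_speed = 0.3 so they still receive enough frames to avoid clipped audio. With illustrative numbers (made up for this example):

```python
ref_audio_len = 500  # reference audio length in mel frames (samples // hop_length)
ref_text_len = 50    # reference text length in UTF-8 bytes
gen_text_len = 8     # a very short generation target, < 10 bytes

at_full_speed = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / 1.0)
at_local_speed = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / 0.3)
print(at_full_speed, at_local_speed)  # 580 vs 766 frames: the slowdown buys headroom
```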
@@ -484,191 +493,69 @@ def infer_batch_process(
        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        generated_waves.append(generated_wave)
        spectrograms.append(generated_mel_spec[0].cpu().numpy())

    # Combine all generated waves with cross-fading
    if cross_fade_duration <= 0:
        # Simply concatenate
        final_wave = np.concatenate(generated_waves)
    else:
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Overlapping parts
            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            # Fade out and fade in
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)

            # Cross-faded overlap
            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            # Combine
            new_wave = np.concatenate(
                [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
            )

            final_wave = new_wave

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, target_sample_rate, combined_spectrogram
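The cross-fade removes audible seams between consecutive batch waveforms: over an overlap of cross_fade_samples, the previous wave is scaled by a linear ramp from 1 to 0 while the next wave ramps from 0 to 1, and the two are summed. The same logic as a standalone helper (a sketch, NumPy only):

```python
import numpy as np


def cross_fade(prev_wave, next_wave, fade_samples):
    """Join two waveforms with a linear cross-fade of at most fade_samples."""
    fade_samples = min(fade_samples, len(prev_wave), len(next_wave))
    if fade_samples <= 0:
        return np.concatenate([prev_wave, next_wave])
    fade_out = np.linspace(1.0, 0.0, fade_samples)
    fade_in = np.linspace(0.0, 1.0, fade_samples)
    overlap = prev_wave[-fade_samples:] * fade_out + next_wave[:fade_samples] * fade_in
    return np.concatenate([prev_wave[:-fade_samples], overlap, next_wave[fade_samples:]])
```

A linear fade sums to unity gain when the overlapping audio is correlated; for uncorrelated material an equal-power (square-root) ramp is the usual alternative.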
# infer batch process for stream mode
def infer_batch_process_stream(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=None,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1,
    speed=1,
    fix_duration=None,
    device=None,
    streaming=False,
    chunk_size=2048,
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    if len(ref_text[-1].encode("utf-8")) == 1:
        ref_text = ref_text + " "

    generated_waves = []
    spectrograms = []

    def process_batch(i, gen_text):
        print(f"Generating audio for batch {i + 1}/{len(gen_text_batches)}: {gen_text}")

        local_speed = speed
        if len(gen_text) < 10:
            local_speed = 0.3

        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        ref_audio_len = audio.shape[-1] // hop_length
        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            ref_text_len = len(ref_text.encode("utf-8"))
            gen_text_len = len(gen_text.encode("utf-8"))
            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)

        with torch.inference_mode():
            generated, _ = model_obj.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

            generated = generated.to(torch.float32)
            generated = generated[:, ref_audio_len:, :]
            generated_mel_spec = generated.permute(0, 2, 1)

            print(f"Generated mel spectrogram shape: {generated_mel_spec.shape}")

            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated_mel_spec)
            elif mel_spec_type == "bigvgan":
                generated_wave = vocoder(generated_mel_spec)

            print(f"Generated wave shape before RMS adjustment: {generated_wave.shape}")

            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms

            print(f"Generated wave shape after RMS adjustment: {generated_wave.shape}")

            generated_wave = generated_wave.squeeze().cpu().numpy()

            if streaming:
                for j in range(0, len(generated_wave), chunk_size):
                    yield generated_wave[j : j + chunk_size], target_sample_rate

                return generated_wave, generated_mel_spec[0].cpu().numpy()
            else:
                yield generated_wave, generated_mel_spec[0].cpu().numpy()
    if streaming:
-        for i, gen_text in enumerate(progress.tqdm(gen_text_batches) if progress else gen_text_batches):
-            for chunk in process_batch(i, gen_text):
+        for gen_text in progress.tqdm(gen_text_batches) if progress else gen_text_batches:
+            for chunk in process_batch(gen_text):
                yield chunk
    else:
        with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)]
+            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
            for future in progress.tqdm(futures) if progress else futures:
                result = future.result()
                if result:
-                    generated_wave, generated_mel_spec = result
+                    generated_wave, generated_mel_spec = next(result)
                    generated_waves.append(generated_wave)
                    spectrograms.append(generated_mel_spec)
        if generated_waves:
            if cross_fade_duration <= 0:
                # Simply concatenate
                final_wave = np.concatenate(generated_waves)
            else:
                # Combine all generated waves with cross-fading
                final_wave = generated_waves[0]
                for i in range(1, len(generated_waves)):
                    prev_wave = final_wave
                    next_wave = generated_waves[i]

                    # Calculate cross-fade samples, ensuring it does not exceed wave lengths
                    cross_fade_samples = int(cross_fade_duration * target_sample_rate)
                    cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

                    if cross_fade_samples <= 0:
                        # No overlap possible, concatenate
                        final_wave = np.concatenate([prev_wave, next_wave])
                        continue

                    # Overlapping parts
                    prev_overlap = prev_wave[-cross_fade_samples:]
                    next_overlap = next_wave[:cross_fade_samples]

                    # Fade out and fade in
                    fade_out = np.linspace(1, 0, cross_fade_samples)
                    fade_in = np.linspace(0, 1, cross_fade_samples)

                    # Cross-faded overlap
                    cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

                    # Combine
                    new_wave = np.concatenate(
                        [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
                    )

                    final_wave = new_wave

            # Create a combined spectrogram
            combined_spectrogram = np.concatenate(spectrograms, axis=1)

            yield final_wave, target_sample_rate, combined_spectrogram
        else:
            yield None, target_sample_rate, None
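With both modes merged, callers pick the output shape with the streaming flag: True yields (chunk, sample_rate) pairs of at most chunk_size samples as each batch finishes, while False yields one (final_wave, sample_rate, spectrogram) tuple. A hypothetical consumer of the streaming path — audio, sr, ref_text, text_batches, model, vocoder, and playback_queue are placeholders assumed to exist in the caller's scope:

```python
# Sketch only: drain the generator and hand chunks to a playback thread.
for chunk, sample_rate in infer_batch_process(
    (audio, sr),
    ref_text,
    text_batches,
    model,
    vocoder,
    progress=None,
    device="cuda",
    streaming=True,
    chunk_size=2048,
):
    playback_queue.put(chunk)
```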
@@ -1,20 +1,31 @@
+import argparse
+import gc
+import logging
+import numpy as np
+import queue
import socket
import struct
+import threading
+import traceback
+import wave
+from importlib.resources import files
+
import torch
import torchaudio
-import logging
-import wave
-import numpy as np
-import argparse
-import traceback
-import gc
-import threading
-import queue
-from nltk.tokenize import sent_tokenize
-from infer.utils_infer import preprocess_ref_audio_text, load_vocoder, load_model, infer_batch_process_stream
-from model.backbones.dit import DiT
-from huggingface_hub import hf_hub_download
-from importlib.resources import files
+import nltk
+from nltk.tokenize import sent_tokenize
+
+from f5_tts.model.backbones.dit import DiT
+from f5_tts.infer.utils_infer import (
+    preprocess_ref_audio_text,
+    load_vocoder,
+    load_model,
+    infer_batch_process,
+)
+
+nltk.download("punkt_tab")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
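nltk.download("punkt_tab") fetches the Punkt sentence-tokenizer data that sent_tokenize relies on (recent NLTK releases ship it as punkt_tab rather than punkt), so the server also starts cleanly on a machine with a fresh NLTK install:

```python
import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt_tab")  # no-op if the resource is already present
print(sent_tokenize("First sentence. And a second one!"))
# -> ['First sentence.', 'And a second one!']
```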
@@ -103,12 +114,13 @@ class TTSStreamingProcessor:
    def _warm_up(self):
        logger.info("Warming up the model...")
        gen_text = "Warm-up text for the model."
-        for _ in infer_batch_process_stream(
+        for _ in infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            [gen_text],
            self.model,
            self.vocoder,
            progress=None,
            device=self.device,
            streaming=True,
        ):
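Note that the loop matters here: because the merged infer_batch_process is a generator, the warm-up must actually iterate it (for _ in ...) to run inference; merely calling the function would only construct the generator object.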
@@ -118,12 +130,13 @@ class TTSStreamingProcessor:
    def generate_stream(self, text, conn):
        text_batches = sent_tokenize(text)

-        audio_stream = infer_batch_process_stream(
+        audio_stream = infer_batch_process(
            (self.audio, self.sr),
            self.ref_text,
            text_batches,
            self.model,
            self.vocoder,
            progress=None,
            device=self.device,
            streaming=True,
            chunk_size=2048,
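The diff is cut off at this point. For context only, one plausible way generate_stream could drain audio_stream and forward it over the connection is sketched below; the framing (struct packing, end marker) is an assumption, not necessarily what socket_server.py does:

```python
# Hypothetical continuation -- not the actual repository code.
import struct

for chunk, sample_rate in audio_stream:
    if chunk is not None and len(chunk) > 0:
        conn.sendall(struct.pack(f"{len(chunk)}f", *chunk))  # raw float32 frames
conn.sendall(b"END")  # assumed end-of-stream marker for the client
```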