diff --git a/src/f5_tts/infer/README.md b/src/f5_tts/infer/README.md
index d49496f..d3bc877 100644
--- a/src/f5_tts/infer/README.md
+++ b/src/f5_tts/infer/README.md
@@ -144,7 +144,14 @@ python src/f5_tts/socket_server.py
 Then create client to communicate
 
+```bash
+# If PyAudio not installed
+sudo apt-get install portaudio19-dev
+pip install pyaudio
+```
+
 ``` python
+# Create the socket_client.py
 import socket
 import asyncio
 import pyaudio
@@ -165,7 +172,6 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
 
     async def play_audio_stream():
         nonlocal first_chunk_time
-        buffer = b""
         p = pyaudio.PyAudio()
         stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
 
@@ -204,7 +210,7 @@ async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
 
 
 if __name__ == "__main__":
-    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
+    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
     asyncio.run(listen_to_F5TTS(text_to_send))
 ```
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index 78a9f0c..0c5ca66 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -390,22 +390,24 @@ def infer_process(
         print("\n")
 
     show_info(f"Generating audio in {len(gen_text_batches)} batches...")
-    return infer_batch_process(
-        (audio, sr),
-        ref_text,
-        gen_text_batches,
-        model_obj,
-        vocoder,
-        mel_spec_type=mel_spec_type,
-        progress=progress,
-        target_rms=target_rms,
-        cross_fade_duration=cross_fade_duration,
-        nfe_step=nfe_step,
-        cfg_strength=cfg_strength,
-        sway_sampling_coef=sway_sampling_coef,
-        speed=speed,
-        fix_duration=fix_duration,
-        device=device,
+    return next(
+        infer_batch_process(
+            (audio, sr),
+            ref_text,
+            gen_text_batches,
+            model_obj,
+            vocoder,
+            mel_spec_type=mel_spec_type,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=device,
+        )
     )
 
 
@@ -428,6 +430,8 @@ def infer_batch_process(
     speed=1,
     fix_duration=None,
     device=None,
+    streaming=False,
+    chunk_size=2048,
 ):
     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -446,7 +450,12 @@ def infer_batch_process(
 
     if len(ref_text[-1].encode("utf-8")) == 1:
         ref_text = ref_text + " "
-    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
+
+    def process_batch(gen_text):
+        local_speed = speed
+        if len(gen_text.encode("utf-8")) < 10:
+            local_speed = 0.3
+
         # Prepare the text
         text_list = [ref_text + gen_text]
         final_text_list = convert_char_to_pinyin(text_list)
@@ -458,7 +467,7 @@ def infer_batch_process(
             # Calculate duration
             ref_text_len = len(ref_text.encode("utf-8"))
             gen_text_len = len(gen_text.encode("utf-8"))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
 
         # inference
         with torch.inference_mode():
@@ -484,191 +493,69 @@ def infer_batch_process(
             # wav -> numpy
             generated_wave = generated_wave.squeeze().cpu().numpy()
-            generated_waves.append(generated_wave)
-            spectrograms.append(generated_mel_spec[0].cpu().numpy())
-
-    # Combine all generated waves with cross-fading
-    if cross_fade_duration <= 0:
-        # Simply concatenate
-        final_wave = np.concatenate(generated_waves)
-    else:
-        final_wave = generated_waves[0]
-        for i in range(1, len(generated_waves)):
-            prev_wave = final_wave
-            next_wave = generated_waves[i]
-
-            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
-            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
-            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
-
-            if cross_fade_samples <= 0:
-                # No overlap possible, concatenate
-                final_wave = np.concatenate([prev_wave, next_wave])
-                continue
-
-            # Overlapping parts
-            prev_overlap = prev_wave[-cross_fade_samples:]
-            next_overlap = next_wave[:cross_fade_samples]
-
-            # Fade out and fade in
-            fade_out = np.linspace(1, 0, cross_fade_samples)
-            fade_in = np.linspace(0, 1, cross_fade_samples)
-
-            # Cross-faded overlap
-            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
-
-            # Combine
-            new_wave = np.concatenate(
-                [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
-            )
-
-            final_wave = new_wave
-
-    # Create a combined spectrogram
-    combined_spectrogram = np.concatenate(spectrograms, axis=1)
-
-    return final_wave, target_sample_rate, combined_spectrogram
-
-
-# infer batch process for stream mode
-def infer_batch_process_stream(
-    ref_audio,
-    ref_text,
-    gen_text_batches,
-    model_obj,
-    vocoder,
-    mel_spec_type="vocos",
-    progress=None,
-    target_rms=0.1,
-    cross_fade_duration=0.15,
-    nfe_step=32,
-    cfg_strength=2.0,
-    sway_sampling_coef=-1,
-    speed=1,
-    fix_duration=None,
-    device=None,
-    streaming=False,
-    chunk_size=2048,
-):
-    audio, sr = ref_audio
-    if audio.shape[0] > 1:
-        audio = torch.mean(audio, dim=0, keepdim=True)
-
-    rms = torch.sqrt(torch.mean(torch.square(audio)))
-    if rms < target_rms:
-        audio = audio * target_rms / rms
-    if sr != target_sample_rate:
-        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
-        audio = resampler(audio)
-    audio = audio.to(device)
-
-    if len(ref_text[-1].encode("utf-8")) == 1:
-        ref_text = ref_text + " "
-
-    generated_waves = []
-    spectrograms = []
-
-    def process_batch(i, gen_text):
-        print(f"Generating audio for batch {i + 1}/{len(gen_text_batches)}: {gen_text}")
-
-        local_speed = speed
-        if len(gen_text) < 10:
-            local_speed = 0.3
-
-        text_list = [ref_text + gen_text]
-        final_text_list = convert_char_to_pinyin(text_list)
-
-        ref_audio_len = audio.shape[-1] // hop_length
-        if fix_duration is not None:
-            duration = int(fix_duration * target_sample_rate / hop_length)
-        else:
-            ref_text_len = len(ref_text.encode("utf-8"))
-            gen_text_len = len(gen_text.encode("utf-8"))
-            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
-
-        with torch.inference_mode():
-            generated, _ = model_obj.sample(
-                cond=audio,
-                text=final_text_list,
-                duration=duration,
-                steps=nfe_step,
-                cfg_strength=cfg_strength,
-                sway_sampling_coef=sway_sampling_coef,
-            )
-
-            generated = generated.to(torch.float32)
-            generated = generated[:, ref_audio_len:, :]
-            generated_mel_spec = generated.permute(0, 2, 1)
-
-            print(f"Generated mel spectrogram shape: {generated_mel_spec.shape}")
-
-            if mel_spec_type == "vocos":
-                generated_wave = vocoder.decode(generated_mel_spec)
-            elif mel_spec_type == "bigvgan":
-                generated_wave = vocoder(generated_mel_spec)
-
-            print(f"Generated wave shape before RMS adjustment: {generated_wave.shape}")
-
-            if rms < target_rms:
-                generated_wave = generated_wave * rms / target_rms
-
-            print(f"Generated wave shape after RMS adjustment: {generated_wave.shape}")
-
-            generated_wave = generated_wave.squeeze().cpu().numpy()
-
             if streaming:
                 for j in range(0, len(generated_wave), chunk_size):
                     yield generated_wave[j : j + chunk_size], target_sample_rate
-
-            return generated_wave, generated_mel_spec[0].cpu().numpy()
+            else:
+                yield generated_wave, generated_mel_spec[0].cpu().numpy()
 
     if streaming:
-        for i, gen_text in enumerate(progress.tqdm(gen_text_batches) if progress else gen_text_batches):
-            for chunk in process_batch(i, gen_text):
+        for gen_text in progress.tqdm(gen_text_batches) if progress else gen_text_batches:
+            for chunk in process_batch(gen_text):
                 yield chunk
     else:
         with ThreadPoolExecutor() as executor:
-            futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)]
+            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
             for future in progress.tqdm(futures) if progress else futures:
                 result = future.result()
                 if result:
-                    generated_wave, generated_mel_spec = result
+                    generated_wave, generated_mel_spec = next(result)
                     generated_waves.append(generated_wave)
                     spectrograms.append(generated_mel_spec)
 
     if generated_waves:
         if cross_fade_duration <= 0:
+            # Simply concatenate
             final_wave = np.concatenate(generated_waves)
         else:
+            # Combine all generated waves with cross-fading
             final_wave = generated_waves[0]
             for i in range(1, len(generated_waves)):
                 prev_wave = final_wave
                 next_wave = generated_waves[i]
 
+                # Calculate cross-fade samples, ensuring it does not exceed wave lengths
                 cross_fade_samples = int(cross_fade_duration * target_sample_rate)
                 cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
 
                 if cross_fade_samples <= 0:
+                    # No overlap possible, concatenate
                     final_wave = np.concatenate([prev_wave, next_wave])
                     continue
 
+                # Overlapping parts
                 prev_overlap = prev_wave[-cross_fade_samples:]
                 next_overlap = next_wave[:cross_fade_samples]
 
+                # Fade out and fade in
                 fade_out = np.linspace(1, 0, cross_fade_samples)
                 fade_in = np.linspace(0, 1, cross_fade_samples)
 
+                # Cross-faded overlap
                 cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
 
+                # Combine
                 new_wave = np.concatenate(
                     [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
                 )
 
                 final_wave = new_wave
 
+        # Create a combined spectrogram
         combined_spectrogram = np.concatenate(spectrograms, axis=1)
 
         yield final_wave, target_sample_rate, combined_spectrogram
+
     else:
         yield None, target_sample_rate, None
diff --git a/src/f5_tts/socket_server.py b/src/f5_tts/socket_server.py
index e941e6a..7175678 100644
--- a/src/f5_tts/socket_server.py
+++ b/src/f5_tts/socket_server.py
@@ -1,20 +1,31 @@
+import argparse
+import gc
+import logging
+import numpy as np
+import queue
 import socket
 import struct
+import threading
+import traceback
+import wave
+from importlib.resources import files
+
 import torch
 import torchaudio
-import logging
-import wave
-import numpy as np
-import argparse
-import traceback
-import gc
-import threading
-import queue
-from nltk.tokenize import sent_tokenize
-from infer.utils_infer import preprocess_ref_audio_text, load_vocoder, load_model, infer_batch_process_stream
-from model.backbones.dit import DiT
 from huggingface_hub import hf_hub_download
-from importlib.resources import files
+
+import nltk
+from nltk.tokenize import sent_tokenize
+
+from f5_tts.model.backbones.dit import DiT
+from f5_tts.infer.utils_infer import (
+    preprocess_ref_audio_text,
+    load_vocoder,
+    load_model,
+    infer_batch_process,
+)
+
+nltk.download("punkt_tab")
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -103,12 +114,13 @@ class TTSStreamingProcessor:
     def _warm_up(self):
logger.info("Warming up the model...") gen_text = "Warm-up text for the model." - for _ in infer_batch_process_stream( + for _ in infer_batch_process( (self.audio, self.sr), self.ref_text, [gen_text], self.model, self.vocoder, + progress=None, device=self.device, streaming=True, ): @@ -118,12 +130,13 @@ class TTSStreamingProcessor: def generate_stream(self, text, conn): text_batches = sent_tokenize(text) - audio_stream = infer_batch_process_stream( + audio_stream = infer_batch_process( (self.audio, self.sr), self.ref_text, text_batches, self.model, self.vocoder, + progress=None, device=self.device, streaming=True, chunk_size=2048,