diff --git a/pyproject.toml b/pyproject.toml index 40baf18..9f2da4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "f5-tts" -version = "0.5.3" +version = "0.6.0" description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching" readme = "README.md" license = {text = "MIT License"} @@ -25,6 +25,7 @@ dependencies = [ "jieba", "librosa", "matplotlib", + "nltk", "numpy<=1.26.4", "pydub", "pypinyin", @@ -40,7 +41,6 @@ dependencies = [ "vocos", "wandb", "x_transformers>=1.31.14", - "nltk" ] [project.optional-dependencies] diff --git a/src/f5_tts/infer/README.md b/src/f5_tts/infer/README.md index f723df1..d49496f 100644 --- a/src/f5_tts/infer/README.md +++ b/src/f5_tts/infer/README.md @@ -155,7 +155,8 @@ import time logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998): + +async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998): client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port))) @@ -164,13 +165,9 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998): async def play_audio_stream(): nonlocal first_chunk_time - buffer = b'' + buffer = b"" p = pyaudio.PyAudio() - stream = p.open(format=pyaudio.paFloat32, - channels=1, - rate=24000, - output=True, - frames_per_buffer=2048) + stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048) try: while True: @@ -195,7 +192,7 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998): logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds") try: - data_to_send = f"{character_name}|{text}".encode('utf-8') + data_to_send = f"{text}".encode("utf-8") await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send) await play_audio_stream() @@ -205,8 +202,9 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998): finally: client_socket.close() + if __name__ == "__main__": - text_to_send = "As a Reader assistant, I'm familiar with blockchain technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let’s break down the components" + text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components" asyncio.run(listen_to_F5TTS(text_to_send)) ``` diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py index 9da5165..78a9f0c 100644 --- a/src/f5_tts/infer/utils_infer.py +++ b/src/f5_tts/infer/utils_infer.py @@ -529,25 +529,26 @@ def infer_batch_process( return final_wave, target_sample_rate, combined_spectrogram + # infer batch process for stream mode def infer_batch_process_stream( - ref_audio, - ref_text, - gen_text_batches, - model_obj, - vocoder, - mel_spec_type="vocos", - progress=None, - target_rms=0.1, - cross_fade_duration=0.15, - nfe_step=32, - cfg_strength=2.0, - sway_sampling_coef=-1, - speed=1, - fix_duration=None, - device=None, - streaming=False, - chunk_size=2048 + ref_audio, + ref_text, + gen_text_batches, + model_obj, + vocoder, + mel_spec_type="vocos", + progress=None, + target_rms=0.1, + cross_fade_duration=0.15, + nfe_step=32, + cfg_strength=2.0, + sway_sampling_coef=-1, + speed=1, + fix_duration=None, + device=None, + streaming=False, + chunk_size=2048, ): audio, sr = ref_audio if audio.shape[0] > 1: @@ -585,7 +586,6 @@ def infer_batch_process_stream( gen_text_len = len(gen_text.encode("utf-8")) duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed) - with torch.inference_mode(): generated, _ = model_obj.sample( cond=audio, @@ -618,7 +618,7 @@ def infer_batch_process_stream( if streaming: for j in range(0, len(generated_wave), chunk_size): - yield generated_wave[j:j + chunk_size], target_sample_rate + yield generated_wave[j : j + chunk_size], target_sample_rate return generated_wave, generated_mel_spec[0].cpu().numpy() @@ -629,7 +629,7 @@ def infer_batch_process_stream( else: with ThreadPoolExecutor() as executor: futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)] - for future in (progress.tqdm(futures) if progress else futures): + for future in progress.tqdm(futures) if progress else futures: result = future.result() if result: generated_wave, generated_mel_spec = result @@ -671,6 +671,8 @@ def infer_batch_process_stream( yield final_wave, target_sample_rate, combined_spectrogram else: yield None, target_sample_rate, None + + # remove silence from generated wav @@ -694,4 +696,4 @@ def save_spectrogram(spectrogram, path): plt.imshow(spectrogram, origin="lower", aspect="auto") plt.colorbar() plt.savefig(path) - plt.close() \ No newline at end of file + plt.close() diff --git a/src/f5_tts/socket_server.py b/src/f5_tts/socket_server.py index a17298e..e941e6a 100644 --- a/src/f5_tts/socket_server.py +++ b/src/f5_tts/socket_server.py @@ -6,7 +6,6 @@ import logging import wave import numpy as np import argparse -import time import traceback import gc import threading @@ -20,8 +19,10 @@ from importlib.resources import files logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class AudioFileWriterThread(threading.Thread): """Threaded file writer to avoid blocking the TTS streaming process.""" + def __init__(self, output_file, sampling_rate): super().__init__() self.output_file = output_file @@ -33,7 +34,7 @@ class AudioFileWriterThread(threading.Thread): def run(self): """Process queued audio data and write it to a file.""" logger.info("AudioFileWriterThread started.") - with wave.open(self.output_file, 'wb') as wf: + with wave.open(self.output_file, "wb") as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(self.sampling_rate) @@ -102,12 +103,19 @@ class TTSStreamingProcessor: def _warm_up(self): logger.info("Warming up the model...") gen_text = "Warm-up text for the model." - for _ in infer_batch_process_stream((self.audio, self.sr), self.ref_text, [gen_text], self.model, self.vocoder, device=self.device, streaming=True): + for _ in infer_batch_process_stream( + (self.audio, self.sr), + self.ref_text, + [gen_text], + self.model, + self.vocoder, + device=self.device, + streaming=True, + ): pass logger.info("Warm-up completed.") def generate_stream(self, text, conn): - start_time = time.time() text_batches = sent_tokenize(text) audio_stream = infer_batch_process_stream( @@ -118,13 +126,13 @@ class TTSStreamingProcessor: self.vocoder, device=self.device, streaming=True, - chunk_size=2048 + chunk_size=2048, ) # Reset the file writer thread if self.file_writer_thread is not None: self.file_writer_thread.stop() - self.file_writer_thread = AudioFileWriterThread('output.wav', self.sampling_rate) + self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate) self.file_writer_thread.start() for audio_chunk, _ in audio_stream: @@ -132,7 +140,7 @@ class TTSStreamingProcessor: logger.info(f"Generated audio chunk of size: {len(audio_chunk)}") # Send audio chunk via socket - conn.sendall(struct.pack(f'{len(audio_chunk)}f', *audio_chunk)) + conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk)) # Write to file asynchronously self.file_writer_thread.add_chunk(audio_chunk) @@ -165,6 +173,7 @@ def handle_client(conn, processor): logger.error(f"Error handling client: {e}") traceback.print_exc() + def start_server(host, port, processor): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind((host, port)) @@ -175,6 +184,7 @@ def start_server(host, port, processor): logger.info(f"Connected by {addr}") handle_client(conn, processor) + if __name__ == "__main__": parser = argparse.ArgumentParser()