formatting

This commit is contained in:
SWivid
2025-02-21 17:00:51 +08:00
parent d68b1f304c
commit 7ee55d773c
4 changed files with 49 additions and 39 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "0.5.3"
version = "0.6.0"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -25,6 +25,7 @@ dependencies = [
"jieba",
"librosa",
"matplotlib",
"nltk",
"numpy<=1.26.4",
"pydub",
"pypinyin",
@@ -40,7 +41,6 @@ dependencies = [
"vocos",
"wandb",
"x_transformers>=1.31.14",
"nltk"
]
[project.optional-dependencies]

View File

@@ -155,7 +155,8 @@ import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
@@ -164,13 +165,9 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
async def play_audio_stream():
nonlocal first_chunk_time
buffer = b''
buffer = b""
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32,
channels=1,
rate=24000,
output=True,
frames_per_buffer=2048)
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
try:
while True:
@@ -195,7 +192,7 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
try:
data_to_send = f"{character_name}|{text}".encode('utf-8')
data_to_send = f"{text}".encode("utf-8")
await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
await play_audio_stream()
@@ -205,8 +202,9 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
finally:
client_socket.close()
if __name__ == "__main__":
text_to_send = "As a Reader assistant, I'm familiar with blockchain technology. which are key to its improved performance in terms of both training speed and inference efficiency.Lets break down the components"
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
asyncio.run(listen_to_F5TTS(text_to_send))
```

View File

@@ -529,25 +529,26 @@ def infer_batch_process(
return final_wave, target_sample_rate, combined_spectrogram
# infer batch process for stream mode
def infer_batch_process_stream(
ref_audio,
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type="vocos",
progress=None,
target_rms=0.1,
cross_fade_duration=0.15,
nfe_step=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
speed=1,
fix_duration=None,
device=None,
streaming=False,
chunk_size=2048
ref_audio,
ref_text,
gen_text_batches,
model_obj,
vocoder,
mel_spec_type="vocos",
progress=None,
target_rms=0.1,
cross_fade_duration=0.15,
nfe_step=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
speed=1,
fix_duration=None,
device=None,
streaming=False,
chunk_size=2048,
):
audio, sr = ref_audio
if audio.shape[0] > 1:
@@ -585,7 +586,6 @@ def infer_batch_process_stream(
gen_text_len = len(gen_text.encode("utf-8"))
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
with torch.inference_mode():
generated, _ = model_obj.sample(
cond=audio,
@@ -618,7 +618,7 @@ def infer_batch_process_stream(
if streaming:
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j:j + chunk_size], target_sample_rate
yield generated_wave[j : j + chunk_size], target_sample_rate
return generated_wave, generated_mel_spec[0].cpu().numpy()
@@ -629,7 +629,7 @@ def infer_batch_process_stream(
else:
with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)]
for future in (progress.tqdm(futures) if progress else futures):
for future in progress.tqdm(futures) if progress else futures:
result = future.result()
if result:
generated_wave, generated_mel_spec = result
@@ -671,6 +671,8 @@ def infer_batch_process_stream(
yield final_wave, target_sample_rate, combined_spectrogram
else:
yield None, target_sample_rate, None
# remove silence from generated wav
@@ -694,4 +696,4 @@ def save_spectrogram(spectrogram, path):
plt.imshow(spectrogram, origin="lower", aspect="auto")
plt.colorbar()
plt.savefig(path)
plt.close()
plt.close()

View File

@@ -6,7 +6,6 @@ import logging
import wave
import numpy as np
import argparse
import time
import traceback
import gc
import threading
@@ -20,8 +19,10 @@ from importlib.resources import files
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AudioFileWriterThread(threading.Thread):
"""Threaded file writer to avoid blocking the TTS streaming process."""
def __init__(self, output_file, sampling_rate):
super().__init__()
self.output_file = output_file
@@ -33,7 +34,7 @@ class AudioFileWriterThread(threading.Thread):
def run(self):
"""Process queued audio data and write it to a file."""
logger.info("AudioFileWriterThread started.")
with wave.open(self.output_file, 'wb') as wf:
with wave.open(self.output_file, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(self.sampling_rate)
@@ -102,12 +103,19 @@ class TTSStreamingProcessor:
def _warm_up(self):
logger.info("Warming up the model...")
gen_text = "Warm-up text for the model."
for _ in infer_batch_process_stream((self.audio, self.sr), self.ref_text, [gen_text], self.model, self.vocoder, device=self.device, streaming=True):
for _ in infer_batch_process_stream(
(self.audio, self.sr),
self.ref_text,
[gen_text],
self.model,
self.vocoder,
device=self.device,
streaming=True,
):
pass
logger.info("Warm-up completed.")
def generate_stream(self, text, conn):
start_time = time.time()
text_batches = sent_tokenize(text)
audio_stream = infer_batch_process_stream(
@@ -118,13 +126,13 @@ class TTSStreamingProcessor:
self.vocoder,
device=self.device,
streaming=True,
chunk_size=2048
chunk_size=2048,
)
# Reset the file writer thread
if self.file_writer_thread is not None:
self.file_writer_thread.stop()
self.file_writer_thread = AudioFileWriterThread('output.wav', self.sampling_rate)
self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
self.file_writer_thread.start()
for audio_chunk, _ in audio_stream:
@@ -132,7 +140,7 @@ class TTSStreamingProcessor:
logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")
# Send audio chunk via socket
conn.sendall(struct.pack(f'{len(audio_chunk)}f', *audio_chunk))
conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))
# Write to file asynchronously
self.file_writer_thread.add_chunk(audio_chunk)
@@ -165,6 +173,7 @@ def handle_client(conn, processor):
logger.error(f"Error handling client: {e}")
traceback.print_exc()
def start_server(host, port, processor):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((host, port))
@@ -175,6 +184,7 @@ def start_server(host, port, processor):
logger.info(f"Connected by {addr}")
handle_client(conn, processor)
if __name__ == "__main__":
parser = argparse.ArgumentParser()