mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-29 14:15:18 -08:00
formatting
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "0.5.3"
|
||||
version = "0.6.0"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
@@ -25,6 +25,7 @@ dependencies = [
|
||||
"jieba",
|
||||
"librosa",
|
||||
"matplotlib",
|
||||
"nltk",
|
||||
"numpy<=1.26.4",
|
||||
"pydub",
|
||||
"pypinyin",
|
||||
@@ -40,7 +41,6 @@ dependencies = [
|
||||
"vocos",
|
||||
"wandb",
|
||||
"x_transformers>=1.31.14",
|
||||
"nltk"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -155,7 +155,8 @@ import time
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
|
||||
|
||||
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
|
||||
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
|
||||
|
||||
@@ -164,13 +165,9 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
|
||||
|
||||
async def play_audio_stream():
|
||||
nonlocal first_chunk_time
|
||||
buffer = b''
|
||||
buffer = b""
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paFloat32,
|
||||
channels=1,
|
||||
rate=24000,
|
||||
output=True,
|
||||
frames_per_buffer=2048)
|
||||
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -195,7 +192,7 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
|
||||
logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
|
||||
|
||||
try:
|
||||
data_to_send = f"{character_name}|{text}".encode('utf-8')
|
||||
data_to_send = f"{text}".encode("utf-8")
|
||||
await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
|
||||
await play_audio_stream()
|
||||
|
||||
@@ -205,8 +202,9 @@ async def listen_to_F5TTS(text, server_ip='localhost', server_port=9998):
|
||||
finally:
|
||||
client_socket.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text_to_send = "As a Reader assistant, I'm familiar with blockchain technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let’s break down the components"
|
||||
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency.Let's break down the components"
|
||||
|
||||
asyncio.run(listen_to_F5TTS(text_to_send))
|
||||
```
|
||||
|
||||
@@ -529,25 +529,26 @@ def infer_batch_process(
|
||||
|
||||
return final_wave, target_sample_rate, combined_spectrogram
|
||||
|
||||
|
||||
# infer batch process for stream mode
|
||||
def infer_batch_process_stream(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
gen_text_batches,
|
||||
model_obj,
|
||||
vocoder,
|
||||
mel_spec_type="vocos",
|
||||
progress=None,
|
||||
target_rms=0.1,
|
||||
cross_fade_duration=0.15,
|
||||
nfe_step=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
speed=1,
|
||||
fix_duration=None,
|
||||
device=None,
|
||||
streaming=False,
|
||||
chunk_size=2048
|
||||
ref_audio,
|
||||
ref_text,
|
||||
gen_text_batches,
|
||||
model_obj,
|
||||
vocoder,
|
||||
mel_spec_type="vocos",
|
||||
progress=None,
|
||||
target_rms=0.1,
|
||||
cross_fade_duration=0.15,
|
||||
nfe_step=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
speed=1,
|
||||
fix_duration=None,
|
||||
device=None,
|
||||
streaming=False,
|
||||
chunk_size=2048,
|
||||
):
|
||||
audio, sr = ref_audio
|
||||
if audio.shape[0] > 1:
|
||||
@@ -585,7 +586,6 @@ def infer_batch_process_stream(
|
||||
gen_text_len = len(gen_text.encode("utf-8"))
|
||||
duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
generated, _ = model_obj.sample(
|
||||
cond=audio,
|
||||
@@ -618,7 +618,7 @@ def infer_batch_process_stream(
|
||||
|
||||
if streaming:
|
||||
for j in range(0, len(generated_wave), chunk_size):
|
||||
yield generated_wave[j:j + chunk_size], target_sample_rate
|
||||
yield generated_wave[j : j + chunk_size], target_sample_rate
|
||||
|
||||
return generated_wave, generated_mel_spec[0].cpu().numpy()
|
||||
|
||||
@@ -629,7 +629,7 @@ def infer_batch_process_stream(
|
||||
else:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(process_batch, i, gen_text) for i, gen_text in enumerate(gen_text_batches)]
|
||||
for future in (progress.tqdm(futures) if progress else futures):
|
||||
for future in progress.tqdm(futures) if progress else futures:
|
||||
result = future.result()
|
||||
if result:
|
||||
generated_wave, generated_mel_spec = result
|
||||
@@ -671,6 +671,8 @@ def infer_batch_process_stream(
|
||||
yield final_wave, target_sample_rate, combined_spectrogram
|
||||
else:
|
||||
yield None, target_sample_rate, None
|
||||
|
||||
|
||||
# remove silence from generated wav
|
||||
|
||||
|
||||
@@ -694,4 +696,4 @@ def save_spectrogram(spectrogram, path):
|
||||
plt.imshow(spectrogram, origin="lower", aspect="auto")
|
||||
plt.colorbar()
|
||||
plt.savefig(path)
|
||||
plt.close()
|
||||
plt.close()
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
import wave
|
||||
import numpy as np
|
||||
import argparse
|
||||
import time
|
||||
import traceback
|
||||
import gc
|
||||
import threading
|
||||
@@ -20,8 +19,10 @@ from importlib.resources import files
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioFileWriterThread(threading.Thread):
|
||||
"""Threaded file writer to avoid blocking the TTS streaming process."""
|
||||
|
||||
def __init__(self, output_file, sampling_rate):
|
||||
super().__init__()
|
||||
self.output_file = output_file
|
||||
@@ -33,7 +34,7 @@ class AudioFileWriterThread(threading.Thread):
|
||||
def run(self):
|
||||
"""Process queued audio data and write it to a file."""
|
||||
logger.info("AudioFileWriterThread started.")
|
||||
with wave.open(self.output_file, 'wb') as wf:
|
||||
with wave.open(self.output_file, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(self.sampling_rate)
|
||||
@@ -102,12 +103,19 @@ class TTSStreamingProcessor:
|
||||
def _warm_up(self):
|
||||
logger.info("Warming up the model...")
|
||||
gen_text = "Warm-up text for the model."
|
||||
for _ in infer_batch_process_stream((self.audio, self.sr), self.ref_text, [gen_text], self.model, self.vocoder, device=self.device, streaming=True):
|
||||
for _ in infer_batch_process_stream(
|
||||
(self.audio, self.sr),
|
||||
self.ref_text,
|
||||
[gen_text],
|
||||
self.model,
|
||||
self.vocoder,
|
||||
device=self.device,
|
||||
streaming=True,
|
||||
):
|
||||
pass
|
||||
logger.info("Warm-up completed.")
|
||||
|
||||
def generate_stream(self, text, conn):
|
||||
start_time = time.time()
|
||||
text_batches = sent_tokenize(text)
|
||||
|
||||
audio_stream = infer_batch_process_stream(
|
||||
@@ -118,13 +126,13 @@ class TTSStreamingProcessor:
|
||||
self.vocoder,
|
||||
device=self.device,
|
||||
streaming=True,
|
||||
chunk_size=2048
|
||||
chunk_size=2048,
|
||||
)
|
||||
|
||||
# Reset the file writer thread
|
||||
if self.file_writer_thread is not None:
|
||||
self.file_writer_thread.stop()
|
||||
self.file_writer_thread = AudioFileWriterThread('output.wav', self.sampling_rate)
|
||||
self.file_writer_thread = AudioFileWriterThread("output.wav", self.sampling_rate)
|
||||
self.file_writer_thread.start()
|
||||
|
||||
for audio_chunk, _ in audio_stream:
|
||||
@@ -132,7 +140,7 @@ class TTSStreamingProcessor:
|
||||
logger.info(f"Generated audio chunk of size: {len(audio_chunk)}")
|
||||
|
||||
# Send audio chunk via socket
|
||||
conn.sendall(struct.pack(f'{len(audio_chunk)}f', *audio_chunk))
|
||||
conn.sendall(struct.pack(f"{len(audio_chunk)}f", *audio_chunk))
|
||||
|
||||
# Write to file asynchronously
|
||||
self.file_writer_thread.add_chunk(audio_chunk)
|
||||
@@ -165,6 +173,7 @@ def handle_client(conn, processor):
|
||||
logger.error(f"Error handling client: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def start_server(host, port, processor):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind((host, port))
|
||||
@@ -175,6 +184,7 @@ def start_server(host, port, processor):
|
||||
logger.info(f"Connected by {addr}")
|
||||
handle_client(conn, processor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user