Mirror of https://github.com/SWivid/F5-TTS.git, synced 2025-12-05 20:40:12 -08:00
v0.3.2 add flags and default values to socket_server.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "0.3.1"
+version = "0.3.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -1,13 +1,14 @@
+import argparse
+import gc
 import socket
 import struct
 import torch
 import torchaudio
+import traceback
+from importlib.resources import files
 from threading import Thread
 
 
-import gc
-import traceback
-
 from cached_path import cached_path
 
 from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
 from model.backbones.dit import DiT
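The two helpers newly imported above feed the CLI defaults added further down: cached_path fetches a remote checkpoint into a local cache, and importlib.resources.files locates data shipped inside the installed package. A minimal sketch of what those defaults evaluate to (assumes the cached-path package and an installed f5_tts distribution; the first call downloads from the Hugging Face Hub):

# Sketch only: shows how the new default values resolve.
from importlib.resources import files

from cached_path import cached_path

# Downloads the checkpoint from the Hugging Face Hub on first use,
# then returns the local cache path on later runs.
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))

# Resolves the reference audio bundled inside the installed f5_tts package.
ref_audio = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))

print(ckpt_file)
print(ref_audio)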
@@ -15,7 +16,9 @@ from model.backbones.dit import DiT
 
 class TTSStreamingProcessor:
     def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device or (
+            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+        )
 
         # Load the model using the provided checkpoint and vocab files
         self.model = load_model(
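The device fallback added above now prefers Apple-silicon MPS before dropping to CPU. A standalone sketch of the same resolution order, for illustration only (the commit inlines the expression directly in __init__):

import torch

def resolve_device(device=None):
    # An explicit --device value always wins.
    if device:
        return device
    # Otherwise prefer CUDA, then Apple-silicon MPS, then CPU,
    # mirroring the expression added in TTSStreamingProcessor.__init__.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(resolve_device())  # e.g. "mps" on an M-series Mac without CUDA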
@@ -137,23 +140,51 @@ def start_server(host, port, processor):
 
 
 if __name__ == "__main__":
-    try:
-        # Load the model and vocoder using the provided files
-        ckpt_file = ""  # pointing your checkpoint "ckpts/model/model_1096.pt"
-        vocab_file = ""  # Add vocab file path if needed
-        ref_audio = ""  # add ref audio"./tests/ref_audio/reference.wav"
-        ref_text = ""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--host", default="0.0.0.0")
+    parser.add_argument("--port", default=9998)
+
+    parser.add_argument(
+        "--ckpt_file",
+        default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
+        help="Path to the model checkpoint file",
+    )
+    parser.add_argument(
+        "--vocab_file",
+        default="",
+        help="Path to the vocab file if customized",
+    )
+
+    parser.add_argument(
+        "--ref_audio",
+        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
+        help="Reference audio to provide model with speaker characteristics",
+    )
+    parser.add_argument(
+        "--ref_text",
+        default="",
+        help="Reference audio subtitle, leave empty to auto-transcribe",
+    )
+
+    parser.add_argument("--device", default=None, help="Device to run the model on")
+    parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference")
+
+    args = parser.parse_args()
+
+    try:
         # Initialize the processor with the model and vocoder
         processor = TTSStreamingProcessor(
-            ckpt_file=ckpt_file,
-            vocab_file=vocab_file,
-            ref_audio=ref_audio,
-            ref_text=ref_text,
-            dtype=torch.float32,
+            ckpt_file=args.ckpt_file,
+            vocab_file=args.vocab_file,
+            ref_audio=args.ref_audio,
+            ref_text=args.ref_text,
+            device=args.device,
+            dtype=args.dtype,
         )
 
         # Start the server
-        start_server("0.0.0.0", 9998, processor)
+        start_server(args.host, args.port, processor)
 
     except KeyboardInterrupt:
         gc.collect()
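With these flags the server can be pointed at a custom checkpoint, vocab file, reference audio, host, and port from the command line instead of editing hard-coded paths. A throwaway client sketch for smoke-testing a running server follows; it assumes (not shown in this diff) that the handler in socket_server.py reads UTF-8 text from the socket and streams raw audio bytes back until the connection closes, so adjust it to the actual wire format in that file:

# Hypothetical smoke-test client; the wire format is an assumption,
# not something this commit defines.
import socket

HOST, PORT = "127.0.0.1", 9998  # matches the new --host/--port defaults

with socket.create_connection((HOST, PORT)) as sock:
    sock.sendall("Hello from a quick F5-TTS socket test.".encode("utf-8"))
    audio = bytearray()
    while True:
        chunk = sock.recv(4096)  # raw audio bytes (assumed)
        if not chunk:
            break
        audio.extend(chunk)

print(f"received {len(audio)} bytes")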