v0.3.2 add flags and default values to socket_server.py

SWivid
2024-12-18 20:32:20 +08:00
parent 0e1f2fcd93
commit deaca8d24c
2 changed files with 49 additions and 18 deletions

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "f5-tts"
-version = "0.3.1"
+version = "0.3.2"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
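
The version bump above is what packaging tools report for an installed build. As a minimal sanity check (not part of this commit), the installed distribution version can be read back with the standard library:

    import importlib.metadata

    # Should print "0.3.2" once this release of the f5-tts package is installed.
    print(importlib.metadata.version("f5-tts"))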

socket_server.py

@@ -1,13 +1,14 @@
+import argparse
+import gc
 import socket
 import struct
 import torch
 import torchaudio
+import traceback
 from importlib.resources import files
 from threading import Thread
-import gc
-import traceback
 from cached_path import cached_path
 from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
 from model.backbones.dit import DiT
@@ -15,7 +16,9 @@ from model.backbones.dit import DiT
 class TTSStreamingProcessor:
     def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device or (
+            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+        )

         # Load the model using the provided checkpoint and vocab files
         self.model = load_model(
@@ -137,23 +140,51 @@ def start_server(host, port, processor):
 if __name__ == "__main__":
-    try:
-        # Load the model and vocoder using the provided files
-        ckpt_file = "" # pointing your checkpoint "ckpts/model/model_1096.pt"
-        vocab_file = "" # Add vocab file path if needed
-        ref_audio = "" # add ref audio"./tests/ref_audio/reference.wav"
-        ref_text = ""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", default="0.0.0.0")
+    parser.add_argument("--port", default=9998)
+    parser.add_argument(
+        "--ckpt_file",
+        default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
+        help="Path to the model checkpoint file",
+    )
+    parser.add_argument(
+        "--vocab_file",
+        default="",
+        help="Path to the vocab file if customized",
+    )
+    parser.add_argument(
+        "--ref_audio",
+        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
+        help="Reference audio to provide model with speaker characteristics",
+    )
+    parser.add_argument(
+        "--ref_text",
+        default="",
+        help="Reference audio subtitle, leave empty to auto-transcribe",
+    )
+    parser.add_argument("--device", default=None, help="Device to run the model on")
+    parser.add_argument("--dtype", default=torch.float32, help="Data type to use for model inference")
+    args = parser.parse_args()
+
+    try:
         # Initialize the processor with the model and vocoder
         processor = TTSStreamingProcessor(
-            ckpt_file=ckpt_file,
-            vocab_file=vocab_file,
-            ref_audio=ref_audio,
-            ref_text=ref_text,
-            dtype=torch.float32,
+            ckpt_file=args.ckpt_file,
+            vocab_file=args.vocab_file,
+            ref_audio=args.ref_audio,
+            ref_text=args.ref_text,
+            device=args.device,
+            dtype=args.dtype,
         )

         # Start the server
-        start_server("0.0.0.0", 9998, processor)
+        start_server(args.host, args.port, processor)

     except KeyboardInterrupt:
         gc.collect()
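
Beyond the new CLI flags, the device default in TTSStreamingProcessor.__init__ now falls back from CUDA to Apple's MPS backend before settling on CPU. The same selection chain, shown standalone as a small sketch:

    import torch

    # Same fallback order as the new __init__ default:
    # prefer CUDA, then Apple Silicon's MPS backend, then plain CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"inference would run on: {device}")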
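
One subtlety of the new --dtype flag: the torch.float32 default is a real dtype object, but any value supplied on the command line reaches the program as a plain string, since no type= converter is declared. A hypothetical helper (not part of this commit) that normalizes both cases:

    import torch

    # Hypothetical normalizer; the committed code passes args.dtype through unchanged.
    def parse_dtype(value):
        if isinstance(value, torch.dtype):  # the argparse default, torch.float32
            return value
        return getattr(torch, value)        # e.g. "float16" -> torch.float16

    print(parse_dtype(torch.float32), parse_dtype("float16"))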
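
With every flag carrying a working default, the server now starts without editing the file, e.g. python socket_server.py, optionally overridden with flags such as --port 9998 or --device cpu. The wire protocol itself is untouched by this commit; assuming the existing server accepts UTF-8 text over a plain TCP connection and streams raw audio bytes back until it stops sending, a client might be sketched as follows (the framing details are assumptions, not taken from this diff):

    import socket

    # Hypothetical client for the streaming TTS server; host/port mirror the new
    # --host/--port defaults. The framing (send UTF-8 text, read until the server
    # stops sending) is an assumption, not something this diff specifies.
    HOST, PORT = "127.0.0.1", 9998

    def synthesize(text: str) -> bytes:
        chunks = []
        with socket.create_connection((HOST, PORT)) as sock:
            sock.sendall(text.encode("utf-8"))
            while True:
                data = sock.recv(4096)
                if not data:  # connection closed: stream finished
                    break
                chunks.append(data)
        return b"".join(chunks)

    if __name__ == "__main__":
        audio_bytes = synthesize("Hello from the F5-TTS socket server.")
        print(f"received {len(audio_bytes)} bytes of streamed audio")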