mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-26 12:51:16 -08:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac79d0ec1e | ||
|
|
dad398c0c1 | ||
|
|
3d969bf78d | ||
|
|
7c741c05f9 | ||
|
|
6d1a1e886a | ||
|
|
b4efcd836a | ||
|
|
818b868fab | ||
|
|
e6fee5e9ba | ||
|
|
2de214c122 | ||
|
|
2999f642ce | ||
|
|
03cff73343 | ||
|
|
63c513840d | ||
|
|
3e6b6c0c0c | ||
|
|
f00ac4d06b | ||
|
|
b0658bfd24 | ||
|
|
0cae51d646 | ||
|
|
95976041f2 | ||
|
|
ba1bf74215 | ||
|
|
536c29ac57 | ||
|
|
c4c61b0110 | ||
|
|
5f80fec160 | ||
|
|
178cb8afe6 | ||
|
|
761c7ed938 | ||
|
|
13fd6f8e07 | ||
|
|
b2284b6cff | ||
|
|
4b4359bc39 | ||
|
|
fe5c562212 | ||
|
|
2374f8ec39 | ||
|
|
f4f10bff6c | ||
|
|
9771ec6a3a | ||
|
|
4b3cd13382 |
@@ -3,11 +3,14 @@ repos:
|
||||
# Ruff version.
|
||||
rev: v0.11.2
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
name: ruff linter
|
||||
args: [--fix]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
name: ruff formatter
|
||||
- id: ruff
|
||||
name: ruff sorter
|
||||
args: [--select, I, --fix]
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
|
||||
28
README.md
28
README.md
@@ -107,6 +107,21 @@ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,targ
|
||||
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
|
||||
```
|
||||
|
||||
### Runtime
|
||||
|
||||
Deployment solution with Triton and TensorRT-LLM.
|
||||
|
||||
#### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF | Mode |
|
||||
|---------------------|----------------|-------------|--------|-----------------|
|
||||
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
|
||||
|
||||
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
|
||||
## Inference
|
||||
|
||||
@@ -179,19 +194,6 @@ f5-tts_infer-cli -c custom.toml
|
||||
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
|
||||
```
|
||||
|
||||
### 3. Runtime
|
||||
|
||||
Deployment solution with Triton and TensorRT-LLM.
|
||||
|
||||
#### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs.
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF |
|
||||
|-------|-------------|----------------|-------|
|
||||
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
|
||||
|
||||
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "1.1.0"
|
||||
version = "1.1.4"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
|
||||
@@ -6,5 +6,5 @@ target-version = "py310"
|
||||
dummy-variable-rgx = "^_.*$"
|
||||
|
||||
[lint.isort]
|
||||
force-single-line = true
|
||||
force-single-line = false
|
||||
lines-after-imports = 2
|
||||
|
||||
@@ -9,13 +9,13 @@ from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
transcribe,
|
||||
preprocess_ref_audio_text,
|
||||
infer_process,
|
||||
remove_silence_for_generated_wav,
|
||||
save_spectrogram,
|
||||
transcribe,
|
||||
)
|
||||
from f5_tts.model.utils import seed_everything
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import argparse
|
||||
@@ -23,6 +24,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
accelerator = Accelerator()
|
||||
device = f"cuda:{accelerator.process_index}"
|
||||
|
||||
|
||||
@@ -5,17 +5,16 @@ import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_librispeech_test,
|
||||
run_asr_wer,
|
||||
run_sim,
|
||||
)
|
||||
|
||||
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
|
||||
|
||||
|
||||
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
@@ -5,17 +5,16 @@ import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_seed_tts_test,
|
||||
run_asr_wer,
|
||||
run_sim,
|
||||
)
|
||||
|
||||
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
|
||||
|
||||
|
||||
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
@@ -14,20 +14,20 @@ from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
mel_spec_type,
|
||||
target_rms,
|
||||
cross_fade_duration,
|
||||
nfe_step,
|
||||
cfg_strength,
|
||||
sway_sampling_coef,
|
||||
speed,
|
||||
fix_duration,
|
||||
cross_fade_duration,
|
||||
device,
|
||||
fix_duration,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
mel_spec_type,
|
||||
nfe_step,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
speed,
|
||||
sway_sampling_coef,
|
||||
target_rms,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
import re
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from importlib.resources import files
|
||||
|
||||
import click
|
||||
@@ -17,6 +18,7 @@ import torchaudio
|
||||
from cached_path import cached_path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
import spaces
|
||||
|
||||
@@ -32,15 +34,15 @@ def gpu_decorator(func):
|
||||
return func
|
||||
|
||||
|
||||
from f5_tts.model import DiT, UNetT
|
||||
from f5_tts.infer.utils_infer import (
|
||||
load_vocoder,
|
||||
load_model,
|
||||
preprocess_ref_audio_text,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
save_spectrogram,
|
||||
)
|
||||
from f5_tts.model import DiT, UNetT
|
||||
|
||||
|
||||
DEFAULT_TTS_MODEL = "F5-TTS_v1"
|
||||
@@ -78,6 +80,8 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
|
||||
vocab_path = str(cached_path(vocab_path))
|
||||
if model_cfg is None:
|
||||
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
||||
elif isinstance(model_cfg, str):
|
||||
model_cfg = json.loads(model_cfg)
|
||||
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
|
||||
|
||||
|
||||
@@ -90,7 +94,7 @@ chat_tokenizer_state = None
|
||||
|
||||
|
||||
@gpu_decorator
|
||||
def generate_response(messages, model, tokenizer):
|
||||
def chat_model_inference(messages, model, tokenizer):
|
||||
"""Generate response using Qwen"""
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
@@ -112,6 +116,17 @@ def generate_response(messages, model, tokenizer):
|
||||
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
|
||||
@gpu_decorator
|
||||
def load_text_from_file(file):
|
||||
if file:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
text = f.read().strip()
|
||||
else:
|
||||
text = ""
|
||||
return gr.update(value=text)
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # NOTE. need to ensure params of infer() hashable
|
||||
@gpu_decorator
|
||||
def infer(
|
||||
ref_audio_orig,
|
||||
@@ -119,6 +134,7 @@ def infer(
|
||||
gen_text,
|
||||
model,
|
||||
remove_silence,
|
||||
seed,
|
||||
cross_fade_duration=0.15,
|
||||
nfe_step=32,
|
||||
speed=1,
|
||||
@@ -128,8 +144,15 @@ def infer(
|
||||
gr.Warning("Please provide reference audio.")
|
||||
return gr.update(), gr.update(), ref_text
|
||||
|
||||
# Set inference seed
|
||||
if seed < 0 or seed > 2**31 - 1:
|
||||
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
|
||||
seed = np.random.randint(0, 2**31 - 1)
|
||||
torch.manual_seed(seed)
|
||||
used_seed = seed
|
||||
|
||||
if not gen_text.strip():
|
||||
gr.Warning("Please enter text to generate.")
|
||||
gr.Warning("Please enter text to generate or upload a text file.")
|
||||
return gr.update(), gr.update(), ref_text
|
||||
|
||||
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
|
||||
@@ -142,7 +165,7 @@ def infer(
|
||||
show_info("Loading E2-TTS model...")
|
||||
E2TTS_ema_model = load_e2tts()
|
||||
ema_model = E2TTS_ema_model
|
||||
elif isinstance(model, list) and model[0] == "Custom":
|
||||
elif isinstance(model, tuple) and model[0] == "Custom":
|
||||
assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
|
||||
global custom_ema_model, pre_custom_path
|
||||
if pre_custom_path != model[1]:
|
||||
@@ -177,7 +200,7 @@ def infer(
|
||||
spectrogram_path = tmp_spectrogram.name
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
|
||||
return (final_sample_rate, final_wave), spectrogram_path, ref_text
|
||||
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
|
||||
|
||||
|
||||
with gr.Blocks() as app_credits:
|
||||
@@ -191,19 +214,38 @@ with gr.Blocks() as app_credits:
|
||||
with gr.Blocks() as app_tts:
|
||||
gr.Markdown("# Batched TTS")
|
||||
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
|
||||
with gr.Row():
|
||||
gen_text_input = gr.Textbox(
|
||||
label="Text to Generate",
|
||||
lines=10,
|
||||
max_lines=40,
|
||||
scale=4,
|
||||
)
|
||||
gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
|
||||
generate_btn = gr.Button("Synthesize", variant="primary")
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
ref_text_input = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
|
||||
lines=2,
|
||||
)
|
||||
remove_silence = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
|
||||
value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
ref_text_input = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
|
||||
lines=2,
|
||||
scale=4,
|
||||
)
|
||||
ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
|
||||
with gr.Row():
|
||||
randomize_seed = gr.Checkbox(
|
||||
label="Randomize Seed",
|
||||
info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
|
||||
value=True,
|
||||
scale=3,
|
||||
)
|
||||
seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
|
||||
with gr.Column(scale=4):
|
||||
remove_silence = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
info="If undesired long silence(s) produced, turn on to automatically detect and crop.",
|
||||
value=False,
|
||||
)
|
||||
speed_slider = gr.Slider(
|
||||
label="Speed",
|
||||
minimum=0.3,
|
||||
@@ -238,21 +280,39 @@ with gr.Blocks() as app_tts:
|
||||
ref_text_input,
|
||||
gen_text_input,
|
||||
remove_silence,
|
||||
randomize_seed,
|
||||
seed_input,
|
||||
cross_fade_duration_slider,
|
||||
nfe_slider,
|
||||
speed_slider,
|
||||
):
|
||||
audio_out, spectrogram_path, ref_text_out = infer(
|
||||
if randomize_seed:
|
||||
seed_input = np.random.randint(0, 2**31 - 1)
|
||||
|
||||
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
|
||||
ref_audio_input,
|
||||
ref_text_input,
|
||||
gen_text_input,
|
||||
tts_model_choice,
|
||||
remove_silence,
|
||||
seed=seed_input,
|
||||
cross_fade_duration=cross_fade_duration_slider,
|
||||
nfe_step=nfe_slider,
|
||||
speed=speed_slider,
|
||||
)
|
||||
return audio_out, spectrogram_path, ref_text_out
|
||||
return audio_out, spectrogram_path, ref_text_out, used_seed
|
||||
|
||||
gen_text_file.upload(
|
||||
load_text_from_file,
|
||||
inputs=[gen_text_file],
|
||||
outputs=[gen_text_input],
|
||||
)
|
||||
|
||||
ref_text_file.upload(
|
||||
load_text_from_file,
|
||||
inputs=[ref_text_file],
|
||||
outputs=[ref_text_input],
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
basic_tts,
|
||||
@@ -261,35 +321,46 @@ with gr.Blocks() as app_tts:
|
||||
ref_text_input,
|
||||
gen_text_input,
|
||||
remove_silence,
|
||||
randomize_seed,
|
||||
seed_input,
|
||||
cross_fade_duration_slider,
|
||||
nfe_slider,
|
||||
speed_slider,
|
||||
],
|
||||
outputs=[audio_output, spectrogram_output, ref_text_input],
|
||||
outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
|
||||
)
|
||||
|
||||
|
||||
def parse_speechtypes_text(gen_text):
|
||||
# Pattern to find {speechtype}
|
||||
pattern = r"\{(.*?)\}"
|
||||
# Pattern to find {str} or {"name": str, "seed": int, "speed": float}
|
||||
pattern = r"(\{.*?\})"
|
||||
|
||||
# Split the text by the pattern
|
||||
tokens = re.split(pattern, gen_text)
|
||||
|
||||
segments = []
|
||||
|
||||
current_style = "Regular"
|
||||
current_type_dict = {
|
||||
"name": "Regular",
|
||||
"seed": -1,
|
||||
"speed": 1.0,
|
||||
}
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if i % 2 == 0:
|
||||
# This is text
|
||||
text = tokens[i].strip()
|
||||
if text:
|
||||
segments.append({"style": current_style, "text": text})
|
||||
current_type_dict["text"] = text
|
||||
segments.append(current_type_dict)
|
||||
else:
|
||||
# This is style
|
||||
style = tokens[i].strip()
|
||||
current_style = style
|
||||
# This is type
|
||||
type_str = tokens[i].strip()
|
||||
try: # if type dict
|
||||
current_type_dict = json.loads(type_str)
|
||||
except json.decoder.JSONDecodeError:
|
||||
type_str = type_str[1:-1] # remove brace {}
|
||||
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
|
||||
|
||||
return segments
|
||||
|
||||
@@ -300,44 +371,55 @@ with gr.Blocks() as app_multistyle:
|
||||
"""
|
||||
# Multiple Speech-Type Generation
|
||||
|
||||
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
|
||||
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
|
||||
"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
"""
|
||||
**Example Input:**
|
||||
{Regular} Hello, I'd like to order a sandwich please.
|
||||
{Surprised} What do you mean you're out of bread?
|
||||
{Sad} I really wanted a sandwich though...
|
||||
{Angry} You know what, darn you and your little shop!
|
||||
{Whisper} I'll just go back home and cry now.
|
||||
{Shouting} Why me?!
|
||||
**Example Input:** <br>
|
||||
{Regular} Hello, I'd like to order a sandwich please. <br>
|
||||
{Surprised} What do you mean you're out of bread? <br>
|
||||
{Sad} I really wanted a sandwich though... <br>
|
||||
{Angry} You know what, darn you and your little shop! <br>
|
||||
{Whisper} I'll just go back home and cry now. <br>
|
||||
{Shouting} Why me?!
|
||||
"""
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
**Example Input 2:**
|
||||
{Speaker1_Happy} Hello, I'd like to order a sandwich please.
|
||||
{Speaker2_Regular} Sorry, we're out of bread.
|
||||
{Speaker1_Sad} I really wanted a sandwich though...
|
||||
{Speaker2_Whisper} I'll give you the last one I was hiding.
|
||||
**Example Input 2:** <br>
|
||||
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
|
||||
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
|
||||
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
|
||||
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
|
||||
"""
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
|
||||
'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
|
||||
)
|
||||
|
||||
# Regular speech type (mandatory)
|
||||
with gr.Row() as regular_row:
|
||||
with gr.Column():
|
||||
with gr.Row(variant="compact") as regular_row:
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
|
||||
regular_insert = gr.Button("Insert Label", variant="secondary")
|
||||
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
|
||||
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
|
||||
with gr.Column(scale=3):
|
||||
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
|
||||
with gr.Column(scale=3):
|
||||
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
|
||||
with gr.Row():
|
||||
regular_seed_slider = gr.Slider(
|
||||
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
|
||||
)
|
||||
regular_speed_slider = gr.Slider(
|
||||
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
|
||||
)
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
|
||||
|
||||
# Regular speech type (max 100)
|
||||
max_speech_types = 100
|
||||
@@ -345,25 +427,55 @@ with gr.Blocks() as app_multistyle:
|
||||
speech_type_names = [regular_name]
|
||||
speech_type_audios = [regular_audio]
|
||||
speech_type_ref_texts = [regular_ref_text]
|
||||
speech_type_ref_text_files = [regular_ref_text_file]
|
||||
speech_type_seeds = [regular_seed_slider]
|
||||
speech_type_speeds = [regular_speed_slider]
|
||||
speech_type_delete_btns = [None]
|
||||
speech_type_insert_btns = [regular_insert]
|
||||
|
||||
# Additional speech types (99 more)
|
||||
for i in range(max_speech_types - 1):
|
||||
with gr.Row(visible=False) as row:
|
||||
with gr.Column():
|
||||
with gr.Row(variant="compact", visible=False) as row:
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
name_input = gr.Textbox(label="Speech Type Name")
|
||||
delete_btn = gr.Button("Delete Type", variant="secondary")
|
||||
insert_btn = gr.Button("Insert Label", variant="secondary")
|
||||
audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
ref_text_input = gr.Textbox(label="Reference Text", lines=2)
|
||||
delete_btn = gr.Button("Delete Type", variant="stop")
|
||||
with gr.Column(scale=3):
|
||||
audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
with gr.Column(scale=3):
|
||||
ref_text_input = gr.Textbox(label="Reference Text", lines=4)
|
||||
with gr.Row():
|
||||
seed_input = gr.Slider(
|
||||
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
|
||||
)
|
||||
speed_input = gr.Slider(
|
||||
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
|
||||
)
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
|
||||
speech_type_rows.append(row)
|
||||
speech_type_names.append(name_input)
|
||||
speech_type_audios.append(audio_input)
|
||||
speech_type_ref_texts.append(ref_text_input)
|
||||
speech_type_ref_text_files.append(ref_text_file_input)
|
||||
speech_type_seeds.append(seed_input)
|
||||
speech_type_speeds.append(speed_input)
|
||||
speech_type_delete_btns.append(delete_btn)
|
||||
speech_type_insert_btns.append(insert_btn)
|
||||
|
||||
# Global logic for all speech types
|
||||
for i in range(max_speech_types):
|
||||
speech_type_audios[i].clear(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[speech_type_ref_texts[i], speech_type_ref_text_files[i]],
|
||||
)
|
||||
speech_type_ref_text_files[i].upload(
|
||||
load_text_from_file,
|
||||
inputs=[speech_type_ref_text_files[i]],
|
||||
outputs=[speech_type_ref_texts[i]],
|
||||
)
|
||||
|
||||
# Button to add speech type
|
||||
add_speech_type_btn = gr.Button("Add Speech Type")
|
||||
|
||||
@@ -385,27 +497,44 @@ with gr.Blocks() as app_multistyle:
|
||||
|
||||
# Function to delete a speech type
|
||||
def delete_speech_type_fn():
|
||||
return gr.update(visible=False), None, None, None
|
||||
return gr.update(visible=False), None, None, None, None
|
||||
|
||||
# Update delete button clicks
|
||||
# Update delete button clicks and ref text file changes
|
||||
for i in range(1, len(speech_type_delete_btns)):
|
||||
speech_type_delete_btns[i].click(
|
||||
delete_speech_type_fn,
|
||||
outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
|
||||
outputs=[
|
||||
speech_type_rows[i],
|
||||
speech_type_names[i],
|
||||
speech_type_audios[i],
|
||||
speech_type_ref_texts[i],
|
||||
speech_type_ref_text_files[i],
|
||||
],
|
||||
)
|
||||
|
||||
# Text input for the prompt
|
||||
gen_text_input_multistyle = gr.Textbox(
|
||||
label="Text to Generate",
|
||||
lines=10,
|
||||
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
|
||||
)
|
||||
with gr.Row():
|
||||
gen_text_input_multistyle = gr.Textbox(
|
||||
label="Text to Generate",
|
||||
lines=10,
|
||||
max_lines=40,
|
||||
scale=4,
|
||||
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
|
||||
)
|
||||
gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
|
||||
|
||||
def make_insert_speech_type_fn(index):
|
||||
def insert_speech_type_fn(current_text, speech_type_name):
|
||||
def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
|
||||
current_text = current_text or ""
|
||||
speech_type_name = speech_type_name or "None"
|
||||
updated_text = current_text + f"{{{speech_type_name}}} "
|
||||
if not speech_type_name:
|
||||
gr.Warning("Please enter speech type name before insert.")
|
||||
return current_text
|
||||
speech_type_dict = {
|
||||
"name": speech_type_name,
|
||||
"seed": speech_type_seed,
|
||||
"speed": speech_type_speed,
|
||||
}
|
||||
updated_text = current_text + json.dumps(speech_type_dict) + " "
|
||||
return updated_text
|
||||
|
||||
return insert_speech_type_fn
|
||||
@@ -414,15 +543,24 @@ with gr.Blocks() as app_multistyle:
|
||||
insert_fn = make_insert_speech_type_fn(i)
|
||||
insert_btn.click(
|
||||
insert_fn,
|
||||
inputs=[gen_text_input_multistyle, speech_type_names[i]],
|
||||
inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
|
||||
outputs=gen_text_input_multistyle,
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
remove_silence_multistyle = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
value=True,
|
||||
)
|
||||
with gr.Accordion("Advanced Settings", open=True):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
show_cherrypick_multistyle = gr.Checkbox(
|
||||
label="Show Cherry-pick Interface",
|
||||
info="Turn on to show interface, picking seeds from previous generations.",
|
||||
value=False,
|
||||
)
|
||||
with gr.Column():
|
||||
remove_silence_multistyle = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
info="Turn on to automatically detect and crop long silences.",
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Generate button
|
||||
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
|
||||
@@ -430,6 +568,30 @@ with gr.Blocks() as app_multistyle:
|
||||
# Output audio
|
||||
audio_output_multistyle = gr.Audio(label="Synthesized Audio")
|
||||
|
||||
# Used seed gallery
|
||||
cherrypick_interface_multistyle = gr.Textbox(
|
||||
label="Cherry-pick Interface",
|
||||
lines=10,
|
||||
max_lines=40,
|
||||
show_copy_button=True,
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
# Logic control to show/hide the cherrypick interface
|
||||
show_cherrypick_multistyle.change(
|
||||
lambda is_visible: gr.update(visible=is_visible),
|
||||
show_cherrypick_multistyle,
|
||||
cherrypick_interface_multistyle,
|
||||
)
|
||||
|
||||
# Function to load text to generate from file
|
||||
gen_text_file_multistyle.upload(
|
||||
load_text_from_file,
|
||||
inputs=[gen_text_file_multistyle],
|
||||
outputs=[gen_text_input_multistyle],
|
||||
)
|
||||
|
||||
@gpu_decorator
|
||||
def generate_multistyle_speech(
|
||||
gen_text,
|
||||
@@ -457,41 +619,60 @@ with gr.Blocks() as app_multistyle:
|
||||
|
||||
# For each segment, generate speech
|
||||
generated_audio_segments = []
|
||||
current_style = "Regular"
|
||||
current_type_name = "Regular"
|
||||
inference_meta_data = ""
|
||||
|
||||
for segment in segments:
|
||||
style = segment["style"]
|
||||
name = segment["name"]
|
||||
seed_input = segment["seed"]
|
||||
speed = segment["speed"]
|
||||
text = segment["text"]
|
||||
|
||||
if style in speech_types:
|
||||
current_style = style
|
||||
if name in speech_types:
|
||||
current_type_name = name
|
||||
else:
|
||||
gr.Warning(f"Type {style} is not available, will use Regular as default.")
|
||||
current_style = "Regular"
|
||||
gr.Warning(f"Type {name} is not available, will use Regular as default.")
|
||||
current_type_name = "Regular"
|
||||
|
||||
try:
|
||||
ref_audio = speech_types[current_style]["audio"]
|
||||
ref_audio = speech_types[current_type_name]["audio"]
|
||||
except KeyError:
|
||||
gr.Warning(f"Please provide reference audio for type {current_style}.")
|
||||
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
|
||||
ref_text = speech_types[current_style].get("ref_text", "")
|
||||
gr.Warning(f"Please provide reference audio for type {current_type_name}.")
|
||||
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
|
||||
ref_text = speech_types[current_type_name].get("ref_text", "")
|
||||
|
||||
# Generate speech for this segment
|
||||
audio_out, _, ref_text_out = infer(
|
||||
ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
|
||||
) # show_info=print no pull to top when generating
|
||||
if seed_input == -1:
|
||||
seed_input = np.random.randint(0, 2**31 - 1)
|
||||
|
||||
# Generate or retrieve speech for this segment
|
||||
audio_out, _, ref_text_out, used_seed = infer(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
text,
|
||||
tts_model_choice,
|
||||
remove_silence,
|
||||
seed=seed_input,
|
||||
cross_fade_duration=0,
|
||||
speed=speed,
|
||||
show_info=print, # no pull to top when generating
|
||||
)
|
||||
sr, audio_data = audio_out
|
||||
|
||||
generated_audio_segments.append(audio_data)
|
||||
speech_types[current_style]["ref_text"] = ref_text_out
|
||||
speech_types[current_type_name]["ref_text"] = ref_text_out
|
||||
inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
|
||||
|
||||
# Concatenate all audio segments
|
||||
if generated_audio_segments:
|
||||
final_audio_data = np.concatenate(generated_audio_segments)
|
||||
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
|
||||
return (
|
||||
[(sr, final_audio_data)]
|
||||
+ [speech_types[name]["ref_text"] for name in speech_types]
|
||||
+ [inference_meta_data]
|
||||
)
|
||||
else:
|
||||
gr.Warning("No audio generated.")
|
||||
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
|
||||
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
|
||||
|
||||
generate_multistyle_btn.click(
|
||||
generate_multistyle_speech,
|
||||
@@ -504,7 +685,7 @@ with gr.Blocks() as app_multistyle:
|
||||
+ [
|
||||
remove_silence_multistyle,
|
||||
],
|
||||
outputs=[audio_output_multistyle] + speech_type_ref_texts,
|
||||
outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
|
||||
)
|
||||
|
||||
# Validation function to disable Generate button if speech types are missing
|
||||
@@ -521,7 +702,7 @@ with gr.Blocks() as app_multistyle:
|
||||
|
||||
# Parse the gen_text to get the speech types used
|
||||
segments = parse_speechtypes_text(gen_text)
|
||||
speech_types_in_text = set(segment["style"] for segment in segments)
|
||||
speech_types_in_text = set(segment["name"] for segment in segments)
|
||||
|
||||
# Check if all speech types in text are available
|
||||
missing_speech_types = speech_types_in_text - speech_types_available
|
||||
@@ -544,10 +725,10 @@ with gr.Blocks() as app_chat:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Voice Chat
|
||||
Have a conversation with an AI using your reference voice!
|
||||
1. Upload a reference audio clip and optionally its transcript.
|
||||
Have a conversation with an AI using your reference voice!
|
||||
1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
|
||||
2. Load the chat model.
|
||||
3. Record your message through your microphone.
|
||||
3. Record your message through your microphone or type it.
|
||||
4. The AI will respond using the reference voice.
|
||||
"""
|
||||
)
|
||||
@@ -603,22 +784,35 @@ Have a conversation with an AI using your reference voice!
|
||||
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
|
||||
with gr.Column():
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
with gr.Row():
|
||||
ref_text_chat = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Optional: Leave blank to auto-transcribe",
|
||||
lines=2,
|
||||
scale=3,
|
||||
)
|
||||
ref_text_file_chat = gr.File(
|
||||
label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
|
||||
)
|
||||
with gr.Row():
|
||||
randomize_seed_chat = gr.Checkbox(
|
||||
label="Randomize Seed",
|
||||
value=True,
|
||||
info="Uncheck to use the seed specified.",
|
||||
scale=3,
|
||||
)
|
||||
seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
|
||||
remove_silence_chat = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
value=True,
|
||||
)
|
||||
ref_text_chat = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Optional: Leave blank to auto-transcribe",
|
||||
lines=2,
|
||||
)
|
||||
system_prompt_chat = gr.Textbox(
|
||||
label="System Prompt",
|
||||
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
|
||||
lines=2,
|
||||
)
|
||||
|
||||
chatbot_interface = gr.Chatbot(label="Conversation")
|
||||
chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@@ -635,132 +829,101 @@ Have a conversation with an AI using your reference voice!
|
||||
send_btn_chat = gr.Button("Send Message")
|
||||
clear_btn_chat = gr.Button("Clear Conversation")
|
||||
|
||||
conversation_state = gr.State(
|
||||
value=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Modify process_audio_input to use model and tokenizer from state
|
||||
# Modify process_audio_input to generate user input
|
||||
@gpu_decorator
|
||||
def process_audio_input(audio_path, text, history, conv_state):
|
||||
def process_audio_input(conv_state, audio_path, text):
|
||||
"""Handle audio or text input from user"""
|
||||
|
||||
if not audio_path and not text.strip():
|
||||
return history, conv_state, ""
|
||||
return conv_state
|
||||
|
||||
if audio_path:
|
||||
text = preprocess_ref_audio_text(audio_path, text)[1]
|
||||
|
||||
if not text.strip():
|
||||
return history, conv_state, ""
|
||||
return conv_state
|
||||
|
||||
conv_state.append({"role": "user", "content": text})
|
||||
history.append((text, None))
|
||||
return conv_state
|
||||
|
||||
response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
|
||||
# Use model and tokenizer from state to get text response
|
||||
@gpu_decorator
|
||||
def generate_text_response(conv_state, system_prompt):
|
||||
"""Generate text response from AI"""
|
||||
|
||||
system_prompt_state = [{"role": "system", "content": system_prompt}]
|
||||
response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
|
||||
|
||||
conv_state.append({"role": "assistant", "content": response})
|
||||
history[-1] = (text, response)
|
||||
|
||||
return history, conv_state, ""
|
||||
return conv_state
|
||||
|
||||
@gpu_decorator
|
||||
def generate_audio_response(history, ref_audio, ref_text, remove_silence):
|
||||
def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
|
||||
"""Generate TTS audio for AI response"""
|
||||
if not history or not ref_audio:
|
||||
return None
|
||||
if not conv_state or not ref_audio:
|
||||
return None, ref_text, seed_input
|
||||
|
||||
last_user_message, last_ai_response = history[-1]
|
||||
if not last_ai_response:
|
||||
return None
|
||||
last_ai_response = conv_state[-1]["content"]
|
||||
if not last_ai_response or conv_state[-1]["role"] != "assistant":
|
||||
return None, ref_text, seed_input
|
||||
|
||||
audio_result, _, ref_text_out = infer(
|
||||
if randomize_seed:
|
||||
seed_input = np.random.randint(0, 2**31 - 1)
|
||||
|
||||
audio_result, _, ref_text_out, used_seed = infer(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
last_ai_response,
|
||||
tts_model_choice,
|
||||
remove_silence,
|
||||
seed=seed_input,
|
||||
cross_fade_duration=0.15,
|
||||
speed=1.0,
|
||||
show_info=print, # show_info=print no pull to top when generating
|
||||
)
|
||||
return audio_result, ref_text_out
|
||||
return audio_result, ref_text_out, used_seed
|
||||
|
||||
def clear_conversation():
|
||||
"""Reset the conversation"""
|
||||
return [], [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
|
||||
}
|
||||
]
|
||||
return [], None
|
||||
|
||||
def update_system_prompt(new_prompt):
|
||||
"""Update the system prompt and reset the conversation"""
|
||||
new_conv_state = [{"role": "system", "content": new_prompt}]
|
||||
return [], new_conv_state
|
||||
|
||||
# Handle audio input
|
||||
audio_input_chat.stop_recording(
|
||||
process_audio_input,
|
||||
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
|
||||
outputs=[audio_output_chat, ref_text_chat],
|
||||
).then(
|
||||
lambda: None,
|
||||
None,
|
||||
audio_input_chat,
|
||||
ref_text_file_chat.upload(
|
||||
load_text_from_file,
|
||||
inputs=[ref_text_file_chat],
|
||||
outputs=[ref_text_chat],
|
||||
)
|
||||
|
||||
# Handle text input
|
||||
text_input_chat.submit(
|
||||
process_audio_input,
|
||||
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
|
||||
outputs=[audio_output_chat, ref_text_chat],
|
||||
).then(
|
||||
lambda: None,
|
||||
None,
|
||||
text_input_chat,
|
||||
)
|
||||
for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
|
||||
user_operation(
|
||||
process_audio_input,
|
||||
inputs=[chatbot_interface, audio_input_chat, text_input_chat],
|
||||
outputs=[chatbot_interface],
|
||||
).then(
|
||||
generate_text_response,
|
||||
inputs=[chatbot_interface, system_prompt_chat],
|
||||
outputs=[chatbot_interface],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[
|
||||
chatbot_interface,
|
||||
ref_audio_chat,
|
||||
ref_text_chat,
|
||||
remove_silence_chat,
|
||||
randomize_seed_chat,
|
||||
seed_input_chat,
|
||||
],
|
||||
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
|
||||
).then(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[audio_input_chat, text_input_chat],
|
||||
)
|
||||
|
||||
# Handle send button
|
||||
send_btn_chat.click(
|
||||
process_audio_input,
|
||||
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
|
||||
outputs=[audio_output_chat, ref_text_chat],
|
||||
).then(
|
||||
lambda: None,
|
||||
None,
|
||||
text_input_chat,
|
||||
)
|
||||
|
||||
# Handle clear button
|
||||
clear_btn_chat.click(
|
||||
clear_conversation,
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
)
|
||||
|
||||
# Handle system prompt change and reset conversation
|
||||
system_prompt_chat.change(
|
||||
update_system_prompt,
|
||||
inputs=system_prompt_chat,
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
)
|
||||
# Handle clear button or system prompt change and reset conversation
|
||||
for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
|
||||
user_operation(
|
||||
clear_conversation,
|
||||
outputs=[chatbot_interface, audio_output_chat],
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
@@ -798,7 +961,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
global tts_model_choice
|
||||
if new_choice == "Custom": # override in case webpage is refreshed
|
||||
custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
|
||||
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
|
||||
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
|
||||
return (
|
||||
gr.update(visible=True, value=custom_ckpt_path),
|
||||
gr.update(visible=True, value=custom_vocab_path),
|
||||
@@ -810,7 +973,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
|
||||
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
|
||||
global tts_model_choice
|
||||
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
|
||||
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
|
||||
with open(last_used_custom, "w", encoding="utf-8") as f:
|
||||
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
|
||||
from importlib.resources import files
|
||||
@@ -7,6 +8,7 @@ from importlib.resources import files
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
@@ -14,6 +16,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
@@ -55,7 +58,8 @@ win_length = model_cfg.model.mel_spec.win_length
|
||||
n_fft = model_cfg.model.mel_spec.n_fft
|
||||
|
||||
|
||||
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
output_dir = "tests"
|
||||
|
||||
|
||||
@@ -152,7 +156,7 @@ for part in parts_to_edit:
|
||||
dim=-1,
|
||||
)
|
||||
offset = end * target_sample_rate
|
||||
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
|
||||
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
|
||||
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
|
||||
audio = audio.to(device)
|
||||
edit_mask = edit_mask.to(device)
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
||||
|
||||
@@ -14,6 +15,7 @@ from importlib.resources import files
|
||||
|
||||
import matplotlib
|
||||
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
import matplotlib.pylab as plt
|
||||
@@ -27,10 +29,8 @@ from transformers import pipeline
|
||||
from vocos import Vocos
|
||||
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import (
|
||||
get_tokenizer,
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
_ref_audio_cache = {}
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from f5_tts.model.cfm import CFM
|
||||
|
||||
from f5_tts.model.backbones.unett import UNetT
|
||||
from f5_tts.model.backbones.dit import DiT
|
||||
from f5_tts.model.backbones.mmdit import MMDiT
|
||||
|
||||
from f5_tts.model.backbones.unett import UNetT
|
||||
from f5_tts.model.cfm import CFM
|
||||
from f5_tts.model.trainer import Trainer
|
||||
|
||||
|
||||
|
||||
@@ -10,19 +10,18 @@ d - dimension
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
AdaLayerNorm_Final,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNorm_Final,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,16 +11,15 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
AdaLayerNorm_Final,
|
||||
ConvPositionEmbedding,
|
||||
MMDiTBlock,
|
||||
AdaLayerNorm_Final,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,24 +8,24 @@ d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
from x_transformers import RMSNorm
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
FeedForward,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from f5_tts.model import CFM
|
||||
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
||||
from f5_tts.model.utils import default, exists
|
||||
|
||||
|
||||
# trainer
|
||||
|
||||
|
||||
|
||||
@@ -5,11 +5,10 @@ import random
|
||||
from collections import defaultdict
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import jieba
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
import torch
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
# seed everything
|
||||
|
||||
@@ -30,18 +30,40 @@ bash run.sh 0 4 F5TTS_Base
|
||||
python3 client_http.py
|
||||
```
|
||||
|
||||
### Benchmark using Dataset
|
||||
### Benchmark using Client-Server Mode
|
||||
```sh
|
||||
num_task=2
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
|
||||
```
|
||||
|
||||
### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
|
||||
### Benchmark using Offline TRT-LLM Mode
|
||||
```sh
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
```
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF |
|
||||
|-------|-------------|----------------|-------|
|
||||
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
|
||||
### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF | Mode |
|
||||
|---------------------|----------------|-------------|--------|-----------------|
|
||||
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
|
||||
|
||||
### Credits
|
||||
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
|
||||
560
src/f5_tts/runtime/triton_trtllm/benchmark.py
Normal file
560
src/f5_tts/runtime/triton_trtllm/benchmark.py
Normal file
@@ -0,0 +1,560 @@
|
||||
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
|
||||
""" Example Usage
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import datasets
|
||||
import jieba
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from f5_tts_trtllm import F5TTS
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session, TensorInfo
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from vocos import Vocos
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="extract speech code")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="huggingface dataset split name",
|
||||
)
|
||||
parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
|
||||
parser.add_argument(
|
||||
"--vocab-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="vocab file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="model path, to load text embedding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tllm-model-dir",
|
||||
required=True,
|
||||
type=str,
|
||||
help="tllm model dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
required=True,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
|
||||
parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
default="vocos",
|
||||
type=str,
|
||||
help="vocoder name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocoder-trt-engine-path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="vocoder trt engine path",
|
||||
)
|
||||
parser.add_argument("--enable-warmup", action="store_true")
|
||||
parser.add_argument("--remove-input-padding", action="store_true")
|
||||
parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
|
||||
parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def padded_mel_batch(ref_mels, max_seq_len):
|
||||
padded_ref_mels = []
|
||||
for mel in ref_mels:
|
||||
# pad along the last dimension
|
||||
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
|
||||
padded_ref_mels.append(padded_ref_mel)
|
||||
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||
return padded_ref_mels
|
||||
|
||||
|
||||
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("data_collator")
|
||||
target_sample_rate = 24000
|
||||
target_rms = 0.1
|
||||
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
for i, item in enumerate(batch):
|
||||
item_id, prompt_text, target_text = (
|
||||
item["id"],
|
||||
item["prompt_text"],
|
||||
item["target_text"],
|
||||
)
|
||||
ids.append(item_id)
|
||||
reference_target_texts_list.append(prompt_text + target_text)
|
||||
|
||||
ref_audio_org, ref_sr = (
|
||||
item["prompt_audio"]["array"],
|
||||
item["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||
if ref_rms < target_rms:
|
||||
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
|
||||
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
ref_mel = ref_mel.squeeze()
|
||||
ref_mel_len = ref_mel.shape[0]
|
||||
assert ref_mel.shape[1] == 100
|
||||
|
||||
ref_mel_list.append(ref_mel)
|
||||
ref_mel_len_list.append(ref_mel_len)
|
||||
|
||||
estimated_reference_target_mel_len.append(
|
||||
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
|
||||
)
|
||||
|
||||
max_seq_len = max(estimated_reference_target_mel_len)
|
||||
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
|
||||
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
|
||||
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
|
||||
|
||||
for i, item in enumerate(text_pad_sequence):
|
||||
text_pad_sequence[i] = F.pad(
|
||||
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
||||
)
|
||||
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
||||
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
|
||||
text_pad_sequence = F.pad(
|
||||
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
||||
)
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
return {
|
||||
"ids": ids,
|
||||
"ref_mel_batch": ref_mel_batch,
|
||||
"ref_mel_len_batch": ref_mel_len_batch,
|
||||
"text_pad_sequence": text_pad_sequence,
|
||||
"estimated_reference_target_mel_len": estimated_reference_target_mel_len,
|
||||
}
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
# Initialize process group with explicit device IDs
|
||||
dist.init_process_group(
|
||||
"nccl",
|
||||
)
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def get_tokenizer(vocab_file_path: str):
|
||||
"""
|
||||
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
||||
- "char" for char-wise tokenizer, need .txt vocab_file
|
||||
- "byte" for utf-8 tokenizer
|
||||
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
||||
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
||||
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
||||
- if use "byte", set to 256 (unicode byte range)
|
||||
"""
|
||||
with open(vocab_file_path, "r", encoding="utf-8") as f:
|
||||
vocab_char_map = {}
|
||||
for i, char in enumerate(f):
|
||||
vocab_char_map[char[:-1]] = i
|
||||
vocab_size = len(vocab_char_map)
|
||||
return vocab_char_map, vocab_size
|
||||
|
||||
|
||||
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
||||
final_reference_target_texts_list = []
|
||||
custom_trans = str.maketrans(
|
||||
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
||||
) # add custom trans here, to address oov
|
||||
|
||||
def is_chinese(c):
|
||||
return "\u3100" <= c <= "\u9fff" # common chinese characters
|
||||
|
||||
for text in reference_target_texts_list:
|
||||
char_list = []
|
||||
text = text.translate(custom_trans)
|
||||
for seg in jieba.cut(text):
|
||||
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||
char_list.append(" ")
|
||||
char_list.extend(seg)
|
||||
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
|
||||
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
||||
for i, c in enumerate(seg):
|
||||
if is_chinese(c):
|
||||
char_list.append(" ")
|
||||
char_list.append(seg_[i])
|
||||
else: # if mixed characters, alphabets and symbols
|
||||
for c in seg:
|
||||
if ord(c) < 256:
|
||||
char_list.extend(c)
|
||||
elif is_chinese(c):
|
||||
char_list.append(" ")
|
||||
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
||||
else:
|
||||
char_list.append(c)
|
||||
final_reference_target_texts_list.append(char_list)
|
||||
|
||||
return final_reference_target_texts_list
|
||||
|
||||
|
||||
def list_str_to_idx(
|
||||
text: Union[List[str], List[List[str]]],
|
||||
vocab_char_map: Dict[str, int], # {char: idx}
|
||||
padding_value=-1,
|
||||
):
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return list_idx_tensors
|
||||
|
||||
|
||||
def load_vocoder(
|
||||
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
|
||||
):
|
||||
if vocoder_name == "vocos":
|
||||
if vocoder_trt_engine_path is not None:
|
||||
vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
|
||||
else:
|
||||
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
||||
if is_local:
|
||||
print(f"Load vocos from local path {local_path}")
|
||||
config_path = f"{local_path}/config.yaml"
|
||||
model_path = f"{local_path}/pytorch_model.bin"
|
||||
else:
|
||||
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
||||
repo_id = "charactr/vocos-mel-24khz"
|
||||
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
||||
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
||||
vocoder = Vocos.from_hparams(config_path)
|
||||
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
||||
from vocos.feature_extractors import EncodecFeatures
|
||||
|
||||
if isinstance(vocoder.feature_extractor, EncodecFeatures):
|
||||
encodec_parameters = {
|
||||
"feature_extractor.encodec." + key: value
|
||||
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
|
||||
}
|
||||
state_dict.update(encodec_parameters)
|
||||
vocoder.load_state_dict(state_dict)
|
||||
vocoder = vocoder.eval().to(device)
|
||||
elif vocoder_name == "bigvgan":
|
||||
raise NotImplementedError("BigVGAN is not implemented yet")
|
||||
return vocoder
|
||||
|
||||
|
||||
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
|
||||
if vocoder == "vocos":
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=24000,
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
hop_length=256,
|
||||
n_mels=100,
|
||||
power=1,
|
||||
center=True,
|
||||
normalized=False,
|
||||
norm=None,
|
||||
).to(device)
|
||||
mel = mel_stft(waveform.to(device))
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel.transpose(1, 2)
|
||||
|
||||
|
||||
class VocosTensorRT:
|
||||
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
|
||||
logger.info(f"Loading vae engine from {engine_path}")
|
||||
self.engine_path = engine_path
|
||||
with open(engine_path, "rb") as f:
|
||||
engine_buffer = f.read()
|
||||
self.session = Session.from_serialized_engine(engine_buffer)
|
||||
self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
|
||||
|
||||
def decode(self, mels):
|
||||
mels = mels.contiguous()
|
||||
inputs = {"mel": mels}
|
||||
output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
|
||||
outputs = {
|
||||
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
|
||||
}
|
||||
ok = self.session.run(inputs, outputs, self.stream)
|
||||
|
||||
assert ok, "Runtime execution failed for vae session"
|
||||
|
||||
samples = outputs["waveform"]
|
||||
return samples
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
|
||||
|
||||
tllm_model_dir = args.tllm_model_dir
|
||||
config_file = os.path.join(tllm_model_dir, "config.json")
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
if args.backend_type == "trt":
|
||||
model = F5TTS(
|
||||
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
import sys
|
||||
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
|
||||
from f5_tts.infer.utils_infer import load_model
|
||||
from f5_tts.model import DiT
|
||||
|
||||
F5TTS_model_cfg = dict(
|
||||
dim=1024,
|
||||
depth=22,
|
||||
heads=16,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
conv_layers=4,
|
||||
pe_attn_head=1,
|
||||
text_mask_padding=False,
|
||||
)
|
||||
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
|
||||
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
|
||||
)
|
||||
|
||||
dataset = load_dataset(
|
||||
"yuekai/seed_tts",
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
def add_estimated_duration(example):
|
||||
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
|
||||
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
|
||||
estimated_duration = prompt_audio_len * scale_factor
|
||||
example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
|
||||
return example
|
||||
|
||||
dataset = dataset.map(add_estimated_duration)
|
||||
dataset = dataset.sort("estimated_duration", reverse=True)
|
||||
if args.use_perf:
|
||||
# dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
|
||||
dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
|
||||
# dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
|
||||
# dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
|
||||
dataset = datasets.concatenate_datasets(dataset_list_short)
|
||||
if world_size > 1:
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
else:
|
||||
# This would disable shuffling
|
||||
sampler = None
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
|
||||
)
|
||||
|
||||
total_steps = len(dataset)
|
||||
|
||||
if args.enable_warmup:
|
||||
for batch in dataloader:
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
if args.backend_type == "trt":
|
||||
_ = model.sample(
|
||||
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
with torch.inference_mode():
|
||||
text_pad_seq -= 1
|
||||
text_pad_seq[text_pad_seq == -2] = -1
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
steps=16,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
decoding_time = 0
|
||||
vocoder_time = 0
|
||||
total_duration = 0
|
||||
if args.use_perf:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
total_decoding_time = time.time()
|
||||
for batch in dataloader:
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_push("data sample")
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
if args.backend_type == "trt":
|
||||
generated, cost_time = model.sample(
|
||||
text_pad_seq,
|
||||
ref_mels,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
use_perf=args.use_perf,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
start_time = time.time()
|
||||
text_pad_seq -= 1
|
||||
text_pad_seq[text_pad_seq == -2] = -1
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=16,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
cost_time = time.time() - start_time
|
||||
decoding_time += cost_time
|
||||
vocoder_start_time = time.time()
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
if args.vocoder == "vocos":
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_push("vocoder decode")
|
||||
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24_000
|
||||
# if ref_rms_list[i] < target_rms:
|
||||
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * target_rms / rms
|
||||
utt = batch["ids"][i]
|
||||
torchaudio.save(
|
||||
f"{args.output_dir}/{utt}.wav",
|
||||
generated_wave,
|
||||
target_sample_rate,
|
||||
)
|
||||
total_duration += generated_wave.shape[1] / target_sample_rate
|
||||
vocoder_time += time.time() - vocoder_start_time
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * len(batch["ids"]))
|
||||
total_decoding_time = time.time() - total_decoding_time
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
rtf = total_decoding_time / total_duration
|
||||
s = f"RTF: {rtf:.4f}\n"
|
||||
s += f"total_duration: {total_duration:.3f} seconds\n"
|
||||
s += f"({total_duration / 3600:.2f} hours)\n"
|
||||
s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
|
||||
s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
|
||||
s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
|
||||
s += f"batch size: {args.batch_size}\n"
|
||||
print(s)
|
||||
|
||||
with open(f"{args.output_dir}/rtf.txt", "w") as f:
|
||||
f.write(s)
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -44,7 +44,6 @@ python3 client_grpc.py \
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
|
||||
@@ -23,10 +23,11 @@
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
|
||||
def get_args():
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import tensorrt as trt
|
||||
import os
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from functools import wraps
|
||||
from typing import List, Optional
|
||||
|
||||
import tensorrt as trt
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
|
||||
|
||||
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
|
||||
@@ -24,16 +24,17 @@
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import json
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import torchaudio
|
||||
import jieba
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
import os
|
||||
|
||||
import jieba
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from f5_tts_trtllm import F5TTS
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
||||
|
||||
def get_tokenizer(vocab_file_path: str):
|
||||
@@ -219,7 +220,9 @@ class TritonPythonModel:
|
||||
|
||||
reference_mel_len.append(mel_features.shape[1])
|
||||
estimated_reference_target_mel_len.append(
|
||||
int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
|
||||
int(
|
||||
mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
|
||||
)
|
||||
)
|
||||
|
||||
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
||||
|
||||
@@ -34,6 +34,7 @@ from .deepseek_v2.model import DeepseekV2ForCausalLM
|
||||
from .dit.model import DiT
|
||||
from .eagle.model import EagleForCausalLM
|
||||
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
|
||||
from .f5tts.model import F5TTS
|
||||
from .falcon.config import FalconConfig
|
||||
from .falcon.model import FalconForCausalLM, FalconModel
|
||||
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
|
||||
@@ -54,12 +55,12 @@ from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodi
|
||||
from .mpt.model import MPTForCausalLM, MPTModel
|
||||
from .nemotron_nas.model import DeciLMForCausalLM
|
||||
from .opt.model import OPTForCausalLM, OPTModel
|
||||
from .phi3.model import Phi3ForCausalLM, Phi3Model
|
||||
from .phi.model import PhiForCausalLM, PhiModel
|
||||
from .phi3.model import Phi3ForCausalLM, Phi3Model
|
||||
from .qwen.model import QWenForCausalLM
|
||||
from .recurrentgemma.model import RecurrentGemmaForCausalLM
|
||||
from .redrafter.model import ReDrafterForCausalLM
|
||||
from .f5tts.model import F5TTS
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BertModel",
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import tensorrt as trt
|
||||
from collections import OrderedDict
|
||||
from tensorrt_llm._common import default_net
|
||||
|
||||
from ..._utils import str_dtype_to_trt
|
||||
from ...functional import Tensor, concat
|
||||
from ...layers import Linear
|
||||
from ...module import Module, ModuleList
|
||||
from ...plugin import current_all_reduce_helper
|
||||
from ..modeling_utils import PretrainedConfig, PretrainedModel
|
||||
from ...functional import Tensor, concat
|
||||
from ...module import Module, ModuleList
|
||||
from tensorrt_llm._common import default_net
|
||||
from ...layers import Linear
|
||||
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
|
||||
|
||||
from .modules import (
|
||||
TimestepEmbedding,
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
)
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
parent_dir = os.path.dirname(current_file_path)
|
||||
|
||||
@@ -3,33 +3,35 @@ from __future__ import annotations
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
from tensorrt_llm._common import default_net
|
||||
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
|
||||
|
||||
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
|
||||
from ...functional import (
|
||||
Tensor,
|
||||
bert_attention,
|
||||
cast,
|
||||
chunk,
|
||||
concat,
|
||||
constant,
|
||||
expand,
|
||||
expand_dims,
|
||||
expand_dims_like,
|
||||
expand_mask,
|
||||
gelu,
|
||||
matmul,
|
||||
permute,
|
||||
shape,
|
||||
silu,
|
||||
slice,
|
||||
permute,
|
||||
expand_mask,
|
||||
expand_dims_like,
|
||||
unsqueeze,
|
||||
matmul,
|
||||
softmax,
|
||||
squeeze,
|
||||
cast,
|
||||
gelu,
|
||||
unsqueeze,
|
||||
view,
|
||||
)
|
||||
from ...functional import expand_dims, view, bert_attention
|
||||
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
|
||||
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
|
||||
from ...module import Module
|
||||
|
||||
|
||||
|
||||
24
src/f5_tts/runtime/triton_trtllm/requirements-pytorch.txt
Normal file
24
src/f5_tts/runtime/triton_trtllm/requirements-pytorch.txt
Normal file
@@ -0,0 +1,24 @@
|
||||
accelerate>=0.33.0
|
||||
bitsandbytes>0.37.0
|
||||
cached_path
|
||||
click
|
||||
datasets
|
||||
ema_pytorch>=0.5.2
|
||||
gradio>=3.45.2
|
||||
hydra-core>=1.3.0
|
||||
jieba
|
||||
librosa
|
||||
matplotlib
|
||||
numpy<=1.26.4
|
||||
pydub
|
||||
pypinyin
|
||||
safetensors
|
||||
soundfile
|
||||
tomli
|
||||
torch>=2.0.0
|
||||
# torchaudio>=2.0.0
|
||||
torchdiffeq
|
||||
tqdm>=4.65.0
|
||||
transformers
|
||||
x_transformers>=1.31.14
|
||||
packaging>=24.2
|
||||
@@ -2,8 +2,8 @@ stage=$1
|
||||
stop_stage=$2
|
||||
model=$3 # F5TTS_Base
|
||||
if [ -z "$model" ]; then
|
||||
echo "Model is none"
|
||||
exit 1
|
||||
echo "Model is none, using default model F5TTS_Base"
|
||||
model=F5TTS_Base
|
||||
fi
|
||||
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
@@ -68,3 +68,43 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "TRT-LLM: offline decoding benchmark test"
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Native Pytorch: offline decoding benchmark test"
|
||||
pip install -r requirements-pytorch.txt
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=pytorch
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--split-name $split_name \
|
||||
--enable-warmup \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
fi
|
||||
@@ -40,6 +40,7 @@ import torch as th
|
||||
import torch.nn.functional as F
|
||||
from scipy.signal import check_COLA, get_window
|
||||
|
||||
|
||||
support_clp_op = None
|
||||
if th.__version__ >= "1.7.0":
|
||||
from torch.fft import rfft as fft
|
||||
|
||||
@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import str_dtype_to_torch
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
|
||||
|
||||
@@ -12,13 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from conv_stft import STFT
|
||||
from huggingface_hub import hf_hub_download
|
||||
from vocos import Vocos
|
||||
import argparse
|
||||
|
||||
|
||||
opset_version = 17
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from f5_tts.model import CFM, DiT
|
||||
|
||||
import torch
|
||||
import thop
|
||||
import torch
|
||||
|
||||
from f5_tts.model import CFM, DiT
|
||||
|
||||
|
||||
""" ~155M """
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import socket
|
||||
import asyncio
|
||||
import pyaudio
|
||||
import numpy as np
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import gc
|
||||
import logging
|
||||
import numpy as np
|
||||
import queue
|
||||
import socket
|
||||
import struct
|
||||
@@ -10,6 +9,7 @@ import traceback
|
||||
import wave
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -18,12 +18,13 @@ from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
chunk_text,
|
||||
preprocess_ref_audio_text,
|
||||
load_vocoder,
|
||||
load_model,
|
||||
infer_batch_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
preprocess_ref_audio_text,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import subprocess # For invoking ffprobe
|
||||
import shutil
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess # For invoking ffprobe
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import argparse
|
||||
@@ -16,12 +17,10 @@ from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import (
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
|
||||
@@ -7,20 +7,18 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import (
|
||||
repetition_found,
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
|
||||
|
||||
|
||||
out_zh = {
|
||||
|
||||
94
src/f5_tts/train/datasets/prepare_emilia_v2.py
Normal file
94
src/f5_tts/train/datasets/prepare_emilia_v2.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
|
||||
# prepares Emilia dataset with the new format w/ Emilia-YODAS
|
||||
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import repetition_found
|
||||
|
||||
|
||||
# Define filters for exclusion
|
||||
out_en = set()
|
||||
en_filters = ["ا", "い", "て"]
|
||||
|
||||
|
||||
def process_audio_directory(audio_dir):
|
||||
sub_result, durations, vocab_set = [], [], set()
|
||||
bad_case_en = 0
|
||||
|
||||
for file in audio_dir.iterdir():
|
||||
if file.suffix == ".json":
|
||||
with open(file, "r") as f:
|
||||
obj = json.load(f)
|
||||
text = obj["text"]
|
||||
if any(f in text for f in en_filters) or repetition_found(text, length=4):
|
||||
bad_case_en += 1
|
||||
continue
|
||||
|
||||
duration = obj["duration"]
|
||||
audio_file = file.with_suffix(".mp3")
|
||||
if audio_file.exists():
|
||||
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
|
||||
durations.append(duration)
|
||||
vocab_set.update(list(text))
|
||||
|
||||
return sub_result, durations, vocab_set, bad_case_en
|
||||
|
||||
|
||||
def main():
|
||||
assert tokenizer in ["pinyin", "char"]
|
||||
result, duration_list, text_vocab_set = [], [], set()
|
||||
total_bad_case_en = 0
|
||||
|
||||
executor = ProcessPoolExecutor(max_workers=max_workers)
|
||||
futures = []
|
||||
dataset_path = Path(dataset_dir)
|
||||
for sub_dir in dataset_path.iterdir():
|
||||
if sub_dir.is_dir():
|
||||
futures.append(executor.submit(process_audio_directory, sub_dir))
|
||||
|
||||
for future in tqdm(futures, total=len(futures)):
|
||||
sub_result, durations, vocab_set, bad_case_en = future.result()
|
||||
result.extend(sub_result)
|
||||
duration_list.extend(durations)
|
||||
text_vocab_set.update(vocab_set)
|
||||
total_bad_case_en += bad_case_en
|
||||
|
||||
executor.shutdown()
|
||||
|
||||
if not os.path.exists(f"{save_dir}"):
|
||||
os.makedirs(f"{save_dir}")
|
||||
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||
|
||||
with open(f"{save_dir}/vocab.txt", "w") as f:
|
||||
for vocab in sorted(text_vocab_set):
|
||||
f.write(vocab + "\n")
|
||||
|
||||
print(f"For {dataset_name}, sample count: {len(result)}")
|
||||
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
||||
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
print(f"Bad en transcription case: {total_bad_case_en}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
max_workers = 32
|
||||
tokenizer = "char"
|
||||
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
|
||||
dataset_name = f"Emilia_EN_{tokenizer}"
|
||||
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
|
||||
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
|
||||
|
||||
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
|
||||
main()
|
||||
@@ -1,15 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
import soundfile as sf
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def deal_with_audio_dir(audio_dir):
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
import soundfile as sf
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -4,15 +4,16 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from tqdm import tqdm
|
||||
|
||||
import torchaudio
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from importlib.resources import files
|
||||
|
||||
from cached_path import cached_path
|
||||
|
||||
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
from f5_tts.model import CFM, DiT, Trainer, UNetT
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
# -------------------------- Dataset Settings --------------------------- #
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import gc
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import platform
|
||||
import psutil
|
||||
import queue
|
||||
import random
|
||||
import re
|
||||
import signal
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -16,21 +14,23 @@ import threading
|
||||
import time
|
||||
from glob import glob
|
||||
from importlib.resources import files
|
||||
from scipy.io import wavfile
|
||||
|
||||
import click
|
||||
import gradio as gr
|
||||
import librosa
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import torchaudio
|
||||
from cached_path import cached_path
|
||||
from datasets import Dataset as Dataset_
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from safetensors.torch import load_file, save_file
|
||||
from scipy.io import wavfile
|
||||
|
||||
from f5_tts.api import F5TTS
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
from f5_tts.infer.utils_infer import transcribe
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
training_process = None
|
||||
@@ -138,6 +138,8 @@ def load_settings(project_name):
|
||||
"logger": "none",
|
||||
"bnb_optimizer": False,
|
||||
}
|
||||
if device == "mps":
|
||||
default_settings["mixed_precision"] = "none"
|
||||
|
||||
# Load settings from file if it exists
|
||||
if os.path.isfile(file_setting):
|
||||
|
||||
@@ -10,6 +10,7 @@ from f5_tts.model import CFM, Trainer
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user