runtime trtllm: clean-up v0 code, several fixes.

This commit is contained in:
SWivid
2025-10-20 10:30:58 +00:00
parent 65ada48a62
commit 8d3ec72159
14 changed files with 239 additions and 320 deletions

View File

@@ -20,7 +20,6 @@ from f5_tts.model.modules import (
ConvPositionEmbedding,
DiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -89,8 +88,7 @@ class TextEmbedding(nn.Module):
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
if self.mask_padding:
text_mask = text == 0
@@ -102,10 +100,7 @@ class TextEmbedding(nn.Module):
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
text = text + self.freqs_cis[:seq_len, :]
# convnextv2 blocks
if self.mask_padding:
@@ -242,6 +237,7 @@ class DiT(nn.Module):
audio_mask: bool["b n"] | None = None, # noqa: F722
):
seq_len = x.shape[1]
# TODO. modify to get text_embed one by one (to avoid misalignment when batching), as done in runtime imple.
if cache:
if drop_text:
if self.text_uncond is None:

View File

@@ -252,10 +252,9 @@ class CFM(nn.Module):
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
if not exists(lens): # if lens not acquired by trainer from collate_fn
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
mask = lens_to_mask(lens, length=seq_len)
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)

View File

@@ -0,0 +1,3 @@
# runtime/triton_trtllm related
model.cache
model_repo/

View File

@@ -3,8 +3,7 @@
### Quick Start
Directly launch the service using docker compose.
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
MODEL=F5TTS_v1_Base docker compose up
```
### Build Image
@@ -20,10 +19,12 @@ docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm
```
### Export Models to TensorRT-LLM and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
bash run.sh 0 4 F5TTS_v1_Base
```
> [!NOTE]
> If use custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`. Remember to used matched model version (`F5TTS_v1_Base` for v1, `F5TTS_Base` for v0).
### HTTP Client
```sh
@@ -49,11 +50,11 @@ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
```
### Benchmark Results
@@ -66,4 +67,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pair
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
1. [Yuekai Zhang](https://github.com/yuekaizhang)
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
"""
import argparse
import importlib
import json
import os
import sys
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
torch.manual_seed(0)
@@ -111,22 +117,20 @@ def get_args():
return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
# pad along the last dimension
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
(
ids,
ref_rms_list,
ref_mel_list,
ref_mel_len_list,
estimated_reference_target_mel_len,
reference_target_texts_list,
) = (
[],
[],
[],
[],
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
ref_rms_list.append(ref_rms)
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
ref_audio = ref_audio.to("cuda")
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze()
ref_mel_len = ref_mel.shape[0]
assert ref_mel.shape[1] == 100
ref_mel_len = ref_mel.shape[-1]
assert ref_mel.shape[0] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
max_seq_len = max(estimated_reference_target_mel_len)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_rms_list": ref_rms_list,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
@@ -216,72 +212,6 @@ def init_distributed():
return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
@@ -316,29 +246,11 @@ def load_vocoder(
return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}")
logger.info(f"Loading vocoder engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
@@ -368,20 +280,20 @@ def main():
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
if args.backend_type == "trt":
tllm_model_dir = args.tllm_model_dir
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
tllm_model_config,
debug_mode=False,
tllm_model_dir=tllm_model_dir,
model_path=args.model_path,
vocab_size=vocab_size,
)
elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
@@ -445,20 +357,23 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
text_pad_seq,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
)
elif args.backend_type == "pytorch":
with torch.inference_mode():
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=16,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
@@ -478,13 +393,13 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
ref_mels,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
@@ -494,20 +409,20 @@ def main():
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
target_rms = 0.1
target_sample_rate = 24000
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
@@ -519,13 +434,10 @@ def main():
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
if batch["ref_rms_list"][i] < target_rms:
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",

View File

@@ -30,15 +30,6 @@ python3 client_grpc.py \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
@@ -176,8 +167,7 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
help="triton model_repo module name to request",
)
parser.add_argument(
@@ -206,7 +196,7 @@ def get_args():
"--log-dir",
type=str,
required=False,
default="./tmp",
default="./tests/client_grpc",
help="log directory",
)

View File

@@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import os
import numpy as np
import requests
@@ -65,14 +66,13 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
default="tests/client_http.wav",
help="Path to save the output audio",
)
return parser.parse_args()
@@ -140,4 +140,5 @@ if __name__ == "__main__":
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
sf.write(args.output_audio, audio, 24000, "PCM_16")

View File

@@ -12,6 +12,7 @@ import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
from torch.nn.utils.rnn import pad_sequence
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
@@ -38,20 +39,15 @@ class TextEmbedding(nn.Module):
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
def forward(self, text, seq_len):
text = text + 1
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
text = self.text_blocks(text)
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
@@ -112,20 +108,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
def get_text_embed_dict(ckpt_path, use_ema=True):
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path)
else:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model_params = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
for key in model_params.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
return text_embed_dict
@@ -196,15 +205,14 @@ class F5TTS(object):
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
# self.target_audio_sample_rate = 24000
# self.target_rms = 0.1 # least rms when inference, normalize to if lower
# self.n_fft = 1024
# self.win_length = 1024
# self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
@@ -214,12 +222,21 @@ class F5TTS(object):
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
self.nfe_steps = 32
epss = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
tmp_dim = 256 # WAR: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
@@ -344,7 +361,7 @@ class F5TTS(object):
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
cond_pad_sequence: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
@@ -353,26 +370,43 @@ class F5TTS(object):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
max_seq_len = cond_pad_sequence.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
# get text_embed one by one to avoid misalignment
text_and_drop_embedding_list = []
for i in range(batch):
text_and_drop_embedding_i = self.text_embedding(
torch.cat(
(
text_pad_sequence[i].unsqueeze(0).to(self.device),
torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
),
dim=0,
),
estimated_reference_target_mel_len[i],
)
text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
# pad separately computed text_embed to form batch with max_seq_len
text_and_drop_embedding = pad_sequence(
text_and_drop_embedding_list,
batch_first=True,
padding_value=0,
)
text_embedding = text_and_drop_embedding[0::2]
text_embedding_drop = text_and_drop_embedding[1::2]
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
noise = torch.randn_like(cond_pad_sequence).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text = torch.cat(
(
cond_pad_sequence,
text_embedding,
),
dim=-1,
)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),

View File

@@ -28,7 +28,6 @@ import os
import jieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
@@ -99,7 +98,8 @@ def list_str_to_idx(
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
class TritonPythonModel:
@@ -107,12 +107,12 @@ class TritonPythonModel:
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.target_rms = 0.1 # least rms when inference, normalize to if lower
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.max_mel_len = 4096
self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"]
@@ -181,7 +181,8 @@ class TritonPythonModel:
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
reference_rms_list,
) = [], [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
@@ -208,6 +209,7 @@ class TritonPythonModel:
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
reference_rms_list.append(ref_rms)
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
@@ -237,15 +239,6 @@ class TritonPythonModel:
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
@@ -266,9 +259,8 @@ class TritonPythonModel:
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
if reference_rms_list[i] < self.target_rms:
audio = audio * reference_rms_list[i] / self.target_rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])

View File

@@ -80,7 +80,7 @@ class F5TTS(PretrainedModel):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
max_seq_len = 4096
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size

View File

@@ -1,24 +0,0 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -1,64 +1,66 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
model=$3 # F5TTS_v1_Base | F5TTS_Base
if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_Base"
model=F5TTS_Base
model=F5TTS_v1_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
CKPT_DIR=../../../../ckpts
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
MODEL_REPO=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
echo "Downloading F5-TTS from huggingface"
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
fi
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update
vocab_file=$CKPT_DIR/$model/vocab.txt
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python3 scripts/convert_checkpoint.py \
--timm_ckpt $ckpt_file \
--output_dir $TRTLLM_CKPT_DIR --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
rm -r $MODEL_REPO
cp -r ./model_repo_f5_tts $MODEL_REPO
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
tritonserver --model-repository=$MODEL_REPO
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
split_name=wenetspeech4tts
log_dir=./tests/client_grpc_concurrent_${num_task}_${split_name}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -74,37 +76,37 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt
batch_size=1
if ! python3 -c "import f5_tts" &> /dev/null; then
pip install -e ../../../../
fi
batch_size=1 # set attn_mask_enabled=True if batched
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi

View File

@@ -172,11 +172,12 @@ def parse_arguments():
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
default="F5TTS_v1_Base",
choices=[
"F5TTS_v1_Base",
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
)
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
@@ -184,7 +185,6 @@ def parse_arguments():
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -197,18 +197,29 @@ def parse_arguments():
return args
def convert_timm_dit(args, mapping, dtype="float32"):
def convert_timm_dit(args, mapping, dtype="float32", use_ema=True):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
ckpt_path = args.timm_ckpt
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
model_params = load_file(ckpt_path)
else:
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
prefix = "ema_model.transformer." if use_ema else "transformer."
if any(k.startswith(prefix) for k in model_params.keys()):
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
key[len(prefix) :] if key.startswith(prefix) else key: value
for key, value in model_params.items()
if key.startswith(prefix)
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
@@ -230,7 +241,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
# new_prefix = "f5_transformer."
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
@@ -278,7 +289,7 @@ def save_config(args):
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"dropout": 0.0, # 0.1
"ff_mult": 2,
"mel_dim": 100,
"text_num_embeds": 256,
@@ -296,7 +307,7 @@ def save_config(args):
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
# "exclude_modules": "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Manual installation of TensorRT, in case not using NVIDIA NGC:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
@@ -40,4 +42,3 @@ ${TRTEXEC} \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}