runtime trtllm: clean up v0 code, several fixes.

This commit is contained in:
SWivid
2025-10-20 10:30:58 +00:00
parent 65ada48a62
commit 8d3ec72159
14 changed files with 239 additions and 320 deletions

View File

@@ -20,7 +20,6 @@ from f5_tts.model.modules import (
ConvPositionEmbedding, ConvPositionEmbedding,
DiTBlock, DiTBlock,
TimestepEmbedding, TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis, precompute_freqs_cis,
) )
@@ -89,8 +88,7 @@ class TextEmbedding(nn.Module):
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722 def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1] text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
if self.mask_padding: if self.mask_padding:
text_mask = text == 0 text_mask = text == 0
@@ -102,10 +100,7 @@ class TextEmbedding(nn.Module):
# possible extra modeling # possible extra modeling
if self.extra_modeling: if self.extra_modeling:
# sinus pos emb # sinus pos emb
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long) text = text + self.freqs_cis[:seq_len, :]
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks # convnextv2 blocks
if self.mask_padding: if self.mask_padding:
@@ -242,6 +237,7 @@ class DiT(nn.Module):
audio_mask: bool["b n"] | None = None, # noqa: F722 audio_mask: bool["b n"] | None = None, # noqa: F722
): ):
seq_len = x.shape[1] seq_len = x.shape[1]
# TODO: modify to get text_embed one by one (to avoid misalignment when batching), as done in the runtime implementation.
if cache: if cache:
if drop_text: if drop_text:
if self.text_uncond is None: if self.text_uncond is None:
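Note on the hunk above: it drops `get_pos_embed_indices` in favor of adding `freqs_cis[:seq_len]` directly. Since `batch_start` was always zero here, the two forms are equivalent; a minimal sketch of that equivalence (shapes are illustrative assumptions, not the real config):

```python
import torch

batch, seq_len, text_dim, max_pos = 2, 8, 16, 1024     # illustrative sizes
freqs_cis = torch.randn(max_pos, text_dim)              # stands in for the precomputed sinusoidal table
text = torch.randn(batch, seq_len, text_dim)

# old path: gather per-sample position indices, with batch_start always zero
batch_start = torch.zeros(batch, dtype=torch.long)
pos_idx = batch_start[:, None] + torch.arange(seq_len)  # (b, nt)
old = text + freqs_cis[pos_idx]

# new path: a single slice broadcast over the batch gives the same result
new = text + freqs_cis[:seq_len, :]
assert torch.allclose(old, new)
```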

View File

@@ -252,10 +252,9 @@ class CFM(nn.Module):
assert text.shape[0] == batch assert text.shape[0] == batch
# lens and mask # lens and mask
if not exists(lens): if not exists(lens): # if lens not acquired by trainer from collate_fn
lens = torch.full((batch,), seq_len, device=device) lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
# get a random span to mask out for training conditionally # get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask) frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
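For context, a minimal sketch of the mask construction referenced above, assuming `lens_to_mask` builds a boolean padding mask from per-sample lengths; in the fallback branch every length equals `seq_len`, so the mask is all `True`, which is why the new comment calls it redundant:

```python
import torch

def lens_to_mask(lens: torch.Tensor, length: int) -> torch.Tensor:
    # True on valid frames, False on padding
    seq = torch.arange(length, device=lens.device)
    return seq[None, :] < lens[:, None]          # (b, n)

seq_len = 5
lens = torch.full((2,), seq_len)                 # fallback when lens is not provided
print(lens_to_mask(lens, length=seq_len))        # all True -> effectively a no-op mask
```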

View File

@@ -0,0 +1,3 @@
# runtime/triton_trtllm related
model.cache
model_repo/

View File

@@ -3,8 +3,7 @@
### Quick Start ### Quick Start
Directly launch the service using docker compose. Directly launch the service using docker compose.
```sh ```sh
# TODO: support F5TTS_v1_Base MODEL=F5TTS_v1_Base docker compose up
MODEL=F5TTS_Base docker compose up
``` ```
### Build Image ### Build Image
@@ -20,10 +19,12 @@ docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm
``` ```
### Export Models to TensorRT-LLM and Launch Server ### Export Models to TensorRT-LLM and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper). Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
```sh ```sh
bash run.sh 0 4 F5TTS_Base bash run.sh 0 4 F5TTS_v1_Base
``` ```
> [!NOTE]
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`. Remember to use the matching model version (`F5TTS_v1_Base` for v1, `F5TTS_Base` for v0).
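When pairing a custom checkpoint with its vocab, a quick sanity check is to count the vocab entries the same way the runtime tokenizer does (one token per line); the path below is only an example:

```python
def load_vocab(vocab_file_path: str):
    # mirrors the tokenizer loading used in this runtime: one token per line, index = line number
    vocab_char_map = {}
    with open(vocab_file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            vocab_char_map[line[:-1]] = i
    return vocab_char_map, len(vocab_char_map)

vocab_char_map, vocab_size = load_vocab("../../../../ckpts/F5TTS_v1_Base/vocab.txt")  # example path
print(vocab_size)  # should match the vocab size the checkpoint / engine was built with
```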
### HTTP Client ### HTTP Client
```sh ```sh
@@ -49,11 +50,11 @@ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \ --batch-size $batch_size \
--enable-warmup \ --enable-warmup \
--split-name $split_name \ --split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \ --model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \ --vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \ --vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \ --backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1 --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
``` ```
### Benchmark Results ### Benchmark Results
@@ -66,4 +67,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pair
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch | | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits ### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm) 1. [Yuekai Zhang](https://github.com/yuekaizhang)
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song) # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang # 2025 (authors: Yuekai Zhang)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \ --batch-size $batch_size \
--enable-warmup \ --enable-warmup \
--split-name $split_name \ --split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \ --model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \ --vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \ --vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \ --backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1 --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
""" """
import argparse import argparse
import importlib
import json import json
import os import os
import sys
import time import time
from typing import Dict, List, Union
import datasets import datasets
import jieba
import tensorrt as trt import tensorrt as trt
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
from datasets import load_dataset from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from vocos import Vocos from vocos import Vocos
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
torch.manual_seed(0) torch.manual_seed(0)
@@ -111,22 +117,20 @@ def get_args():
return args return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
# pad along the last dimension
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False): def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf: if use_perf:
torch.cuda.nvtx.range_push("data_collator") torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000 target_sample_rate = 24000
target_rms = 0.1 target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = ( (
ids,
ref_rms_list,
ref_mel_list,
ref_mel_len_list,
estimated_reference_target_mel_len,
reference_target_texts_list,
) = (
[],
[], [],
[], [],
[], [],
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
) )
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float() ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org))) ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
ref_rms_list.append(ref_rms)
if ref_rms < target_rms: if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms ref_audio_org = ref_audio_org * target_rms / ref_rms
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf: if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}") torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda") ref_audio = ref_audio.to("cuda")
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
if use_perf: if use_perf:
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze() ref_mel_len = ref_mel.shape[-1]
ref_mel_len = ref_mel.shape[0] assert ref_mel.shape[0] == 100
assert ref_mel.shape[1] == 100
ref_mel_list.append(ref_mel) ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len) ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append( estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8")))) int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
) )
max_seq_len = max(estimated_reference_target_mel_len) ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list) ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map) text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf: if use_perf:
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
return { return {
"ids": ids, "ids": ids,
"ref_rms_list": ref_rms_list,
"ref_mel_batch": ref_mel_batch, "ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch, "ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence, "text_pad_sequence": text_pad_sequence,
@@ -216,72 +212,6 @@ def init_distributed():
return world_size, local_rank, rank return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{"；": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder( def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
): ):
@@ -316,29 +246,11 @@ def load_vocoder(
return vocoder return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT: class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None): def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="") trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}") logger.info(f"Loading vocoder engine from {engine_path}")
self.engine_path = engine_path self.engine_path = engine_path
with open(engine_path, "rb") as f: with open(engine_path, "rb") as f:
engine_buffer = f.read() engine_buffer = f.read()
@@ -368,20 +280,20 @@ def main():
world_size, local_rank, rank = init_distributed() world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file) vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
if args.backend_type == "trt": if args.backend_type == "trt":
tllm_model_dir = args.tllm_model_dir
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
model = F5TTS( model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size tllm_model_config,
debug_mode=False,
tllm_model_dir=tllm_model_dir,
model_path=args.model_path,
vocab_size=vocab_size,
) )
elif args.backend_type == "pytorch": elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT from f5_tts.model import DiT
@@ -445,20 +357,23 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device) text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"] total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.backend_type == "trt": if args.backend_type == "trt":
_ = model.sample( _ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding text_pad_seq,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
) )
elif args.backend_type == "pytorch": elif args.backend_type == "pytorch":
with torch.inference_mode():
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device) total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
generated, _ = model.sample( generated, _ = model.sample(
cond=ref_mels, cond=ref_mels,
text=text_pad_seq, text=text_pad_seq,
duration=total_mel_lens, duration=total_mel_lens,
steps=16, steps=32,
cfg_strength=2.0, cfg_strength=2.0,
sway_sampling_coef=-1, sway_sampling_coef=-1,
) )
@@ -478,13 +393,13 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device) ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device) text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"] total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.use_perf: if args.use_perf:
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
if args.backend_type == "trt": if args.backend_type == "trt":
generated, cost_time = model.sample( generated, cost_time = model.sample(
text_pad_seq, text_pad_seq,
ref_mels, cond_pad_seq,
ref_mel_lens, ref_mel_lens,
total_mel_lens, total_mel_lens,
remove_input_padding=args.remove_input_padding, remove_input_padding=args.remove_input_padding,
@@ -494,20 +409,20 @@ def main():
total_mel_lens = torch.tensor(total_mel_lens, device=device) total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode(): with torch.inference_mode():
start_time = time.time() start_time = time.time()
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample( generated, _ = model.sample(
cond=ref_mels, cond=ref_mels,
text=text_pad_seq, text=text_pad_seq,
duration=total_mel_lens, duration=total_mel_lens,
lens=ref_mel_lens, lens=ref_mel_lens,
steps=16, steps=32,
cfg_strength=2.0, cfg_strength=2.0,
sway_sampling_coef=-1, sway_sampling_coef=-1,
) )
cost_time = time.time() - start_time cost_time = time.time() - start_time
decoding_time += cost_time decoding_time += cost_time
vocoder_start_time = time.time() vocoder_start_time = time.time()
target_rms = 0.1
target_sample_rate = 24000
for i, gen in enumerate(generated): for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0) gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32) gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
@@ -519,13 +434,10 @@ def main():
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
else: else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu() generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000 if batch["ref_rms_list"][i] < target_rms:
# if ref_rms_list[i] < target_rms: generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
utt = batch["ids"][i] utt = batch["ids"][i]
torchaudio.save( torchaudio.save(
f"{args.output_dir}/{utt}.wav", f"{args.output_dir}/{utt}.wav",

View File

@@ -30,15 +30,6 @@ python3 client_grpc.py \
--huggingface-dataset yuekai/seed_tts \ --huggingface-dataset yuekai/seed_tts \
--split-name test_zh \ --split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task} --log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
""" """
import argparse import argparse
@@ -176,8 +167,7 @@ def get_args():
"--model-name", "--model-name",
type=str, type=str,
default="f5_tts", default="f5_tts",
choices=["f5_tts", "spark_tts"], help="triton model_repo module name to request",
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
) )
parser.add_argument( parser.add_argument(
@@ -206,7 +196,7 @@ def get_args():
"--log-dir", "--log-dir",
type=str, type=str,
required=False, required=False,
default="./tmp", default="./tests/client_grpc",
help="log directory", help="log directory",
) )

View File

@@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse import argparse
import os
import numpy as np import numpy as np
import requests import requests
@@ -65,14 +66,13 @@ def get_args():
"--model-name", "--model-name",
type=str, type=str,
default="f5_tts", default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request", help="triton model_repo module name to request",
) )
parser.add_argument( parser.add_argument(
"--output-audio", "--output-audio",
type=str, type=str,
default="output.wav", default="tests/client_http.wav",
help="Path to save the output audio", help="Path to save the output audio",
) )
return parser.parse_args() return parser.parse_args()
@@ -140,4 +140,5 @@ if __name__ == "__main__":
result = rsp.json() result = rsp.json()
audio = result["outputs"][0]["data"] audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32) audio = np.array(audio, dtype=np.float32)
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
sf.write(args.output_audio, audio, 24000, "PCM_16") sf.write(args.output_audio, audio, 24000, "PCM_16")
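A hedged sketch of the response handling shown above: the waveform samples come back in the Triton JSON response and are written out as a 24 kHz PCM_16 wav, with the output directory created first (the guard for bare filenames is an addition of this sketch):

```python
import os
import numpy as np
import soundfile as sf

def save_response_audio(result: dict, output_audio: str = "tests/client_http.wav"):
    audio = np.array(result["outputs"][0]["data"], dtype=np.float32)
    out_dir = os.path.dirname(output_audio)
    if out_dir:                        # guard added in this sketch so bare filenames also work
        os.makedirs(out_dir, exist_ok=True)
    sf.write(output_audio, audio, 24000, "PCM_16")
```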

View File

@@ -12,6 +12,7 @@ import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session from tensorrt_llm.runtime.session import Session
from torch.nn.utils.rnn import pad_sequence
def remove_tensor_padding(input_tensor, input_tensor_lengths=None): def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
@@ -38,20 +39,15 @@ class TextEmbedding(nn.Module):
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False) self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text): def forward(self, text, seq_len):
# only keep tensors with value not -1 text = text + 1
text_mask = text != -1 text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text_pad_cut_off_index = text_mask.sum(dim=1).max() text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
text = self.text_blocks(text)
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text return text
@@ -112,20 +108,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
return torch.cat([freqs_cos, freqs_sin], dim=-1) return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True): def get_text_embed_dict(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True) ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path)
else:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if use_ema: if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = { checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items() for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"] if k not in ["initted", "step"]
} }
dict_state = checkpoint["model_state_dict"] else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model_params = checkpoint["model_state_dict"]
text_embed_dict = {} text_embed_dict = {}
for key in dict_state.keys(): for key in model_params.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight # transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key: if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key] text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
return text_embed_dict return text_embed_dict
@@ -196,15 +205,14 @@ class F5TTS(object):
self.text_embedding = TextEmbedding( self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
).to(self.device) ).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True) self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
self.target_audio_sample_rate = 24000 # self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio # self.target_rms = 0.1 # least rms when inference, normalize to if lower
self.n_fft = 1024 # self.n_fft = 1024
self.win_length = 1024 # self.win_length = 1024
self.hop_length = 256 # self.hop_length = 256
self.n_mel_channels = 100 self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64 self.head_dim = 64
self.base_rescale_factor = 1.0 self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0 self.interpolation_factor = 1.0
@@ -214,12 +222,21 @@ class F5TTS(object):
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0) self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half() self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half() self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32) self.nfe_steps = 32
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t) epss = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step) delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256 tmp_dim = 256 # WAR: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32) time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2 half_dim = tmp_dim // 2
emb_factor = math.log(10000) / (half_dim - 1) emb_factor = math.log(10000) / (half_dim - 1)
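The schedule above replaces the 16-step linspace with a 32-step grid (plus preset subsets for lower NFE counts; the `epss` table is abridged below) mapped through the same sway-sampling cosine. A minimal sketch:

```python
import torch

nfe_steps = 32
# abridged: the full table also has presets for 5, 6, 7, 10 and 12 steps
epss = {16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32]}

grid = epss.get(nfe_steps, list(range(nfe_steps + 1)))   # indices on a 32-step grid
t = torch.tensor(grid, dtype=torch.float32) / 32          # normalized flow time in [0, 1]
time_step = 1 - torch.cos(torch.pi * t / 2)                # sway-sampled: denser steps near t = 0
delta_t = torch.diff(time_step)                            # per-step sizes for the ODE solver
print(len(delta_t), delta_t[:3])
```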
@@ -344,7 +361,7 @@ class F5TTS(object):
def sample( def sample(
self, self,
text_pad_sequence: torch.Tensor, text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor, cond_pad_sequence: torch.Tensor,
ref_mel_len_batch: torch.Tensor, ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int], estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False, remove_input_padding: bool = False,
@@ -353,26 +370,43 @@ class F5TTS(object):
if use_perf: if use_perf:
torch.cuda.nvtx.range_push("text embedding") torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0] batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1] max_seq_len = cond_pad_sequence.shape[1]
text_pad_sequence_drop = torch.cat( # get text_embed one by one to avoid misalignment
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0 text_and_drop_embedding_list = []
for i in range(batch):
text_and_drop_embedding_i = self.text_embedding(
torch.cat(
(
text_pad_sequence[i].unsqueeze(0).to(self.device),
torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
),
dim=0,
),
estimated_reference_target_mel_len[i],
) )
text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
text_embedding_drop_list = [] # pad separately computed text_embed to form batch with max_seq_len
for i in range(batch + 1): text_and_drop_embedding = pad_sequence(
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device))) text_and_drop_embedding_list,
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0) batch_first=True,
padding_value=0,
)
text_embedding = text_and_drop_embedding[0::2]
text_embedding_drop = text_and_drop_embedding[1::2]
text_embedding = text_embedding_drop_condition[:-1] noise = torch.randn_like(cond_pad_sequence).to(self.device)
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1) rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1) cat_mel_text = torch.cat(
(
cond_pad_sequence,
text_embedding,
),
dim=-1,
)
cat_mel_text_drop = torch.cat( cat_mel_text_drop = torch.cat(
( (
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
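The rewritten `sample` above embeds each item together with its unconditional (all `-1`) twin, pads the interleaved results to a common length, and de-interleaves them back into conditional and unconditional batches. A hedged sketch of that pattern (the helper name is ours):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def embed_cond_and_uncond(text_embedding, text_pad_sequence, target_mel_lens, device):
    pairs = []
    for i, seq_len in enumerate(target_mel_lens):
        cond = text_pad_sequence[i].unsqueeze(0).to(device)
        uncond = torch.full_like(cond, -1)                 # dropped-text twin for CFG
        both = text_embedding(torch.cat((cond, uncond), dim=0), seq_len)
        pairs.extend([both[0], both[1]])
    padded = pad_sequence(pairs, batch_first=True, padding_value=0)
    return padded[0::2], padded[1::2]                      # (conditional, unconditional)
```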

View File

@@ -28,7 +28,6 @@ import os
import jieba import jieba
import torch import torch
import torch.nn.functional as F
import torchaudio import torchaudio
import triton_python_backend_utils as pb_utils import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS from f5_tts_trtllm import F5TTS
@@ -99,7 +98,8 @@ def list_str_to_idx(
padding_value=-1, padding_value=-1,
): # noqa: F722 ): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
class TritonPythonModel: class TritonPythonModel:
@@ -107,12 +107,12 @@ class TritonPythonModel:
self.use_perf = True self.use_perf = True
self.device = torch.device("cuda") self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000 self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio self.target_rms = 0.1 # least rms when inference, normalize to if lower
self.n_fft = 1024 self.n_fft = 1024
self.win_length = 1024 self.win_length = 1024
self.hop_length = 256 self.hop_length = 256
self.n_mel_channels = 100 self.n_mel_channels = 100
self.max_mel_len = 3000 self.max_mel_len = 4096
self.head_dim = 64 self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"] parameters = json.loads(args["model_config"])["parameters"]
@@ -181,7 +181,8 @@ class TritonPythonModel:
reference_target_texts_list, reference_target_texts_list,
estimated_reference_target_mel_len, estimated_reference_target_mel_len,
reference_mel_len, reference_mel_len,
) = [], [], [], [], [] reference_rms_list,
) = [], [], [], [], [], []
mel_features_list = [] mel_features_list = []
if self.use_perf: if self.use_perf:
torch.cuda.nvtx.range_push("preprocess") torch.cuda.nvtx.range_push("preprocess")
@@ -208,6 +209,7 @@ class TritonPythonModel:
ref_rms = torch.sqrt(torch.mean(torch.square(wav))) ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms: if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms wav = wav * self.target_rms / ref_rms
reference_rms_list.append(ref_rms)
if self.reference_sample_rate != self.target_audio_sample_rate: if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav) wav = self.resampler(wav)
wav = wav.to(self.device) wav = wav.to(self.device)
@@ -237,15 +239,6 @@ class TritonPythonModel:
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map) text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf: if self.use_perf:
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
@@ -266,9 +259,8 @@ class TritonPythonModel:
estimated_mel_len = estimated_reference_target_mel_len[i] estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2) denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item) audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio))) if reference_rms_list[i] < self.target_rms:
if rms < self.target_rms: audio = audio * reference_rms_list[i] / self.target_rms
audio = audio * self.target_rms / rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio]) inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
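`list_str_to_idx` above now returns a `-1`-padded batch directly; inside the model the `text + 1` shift (see the DiT hunk at the top of this commit) turns that padding into index 0, the reserved filler token. A toy illustration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

vocab_char_map = {"ni3": 0, "hao3": 1, "shi4": 2}   # toy vocab; the real one comes from vocab.txt
batch = [["ni3", "hao3"], ["shi4"]]

idx = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in batch]
text = pad_sequence(idx, padding_value=-1, batch_first=True)
print(text)        # tensor([[ 0,  1], [ 2, -1]])
print(text + 1)    # tensor([[ 1,  2], [ 3,  0]])  -> padding becomes the filler token 0
```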

View File

@@ -80,7 +80,7 @@ class F5TTS(PretrainedModel):
max_batch_size = kwargs["max_batch_size"] max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size] batch_size_range = [2, 2, max_batch_size]
mel_size = 100 mel_size = 100
max_seq_len = 3000 max_seq_len = 4096
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size] num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512 hidden_size = 512
concat_feature_dim = mel_size + hidden_size concat_feature_dim = mel_size + hidden_size

View File

@@ -1,24 +0,0 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -1,64 +1,66 @@
stage=$1 stage=$1
stop_stage=$2 stop_stage=$2
model=$3 # F5TTS_Base model=$3 # F5TTS_v1_Base | F5TTS_Base
if [ -z "$model" ]; then if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_Base" model=F5TTS_v1_Base
model=F5TTS_Base
fi fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model" echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS CKPT_DIR=../../../../ckpts
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
vocoder_trt_engine_path=vocos_vocoder.plan VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
model_repo=./model_repo VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
MODEL_REPO=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface" echo "Downloading F5-TTS from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
fi fi
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # by default, select the latest (highest-numbered) checkpoint
vocab_file=$CKPT_DIR/$model/vocab.txt
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint" echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \ python3 scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \ --timm_ckpt $ckpt_file \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model --output_dir $TRTLLM_CKPT_DIR --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \ trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
--max_batch_size 8 \ --max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable --output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder" echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
fi fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server" echo "Building triton server"
rm -r $model_repo rm -r $MODEL_REPO
cp -r ./model_repo_f5_tts $model_repo cp -r ./model_repo_f5_tts $MODEL_REPO
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
fi fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server" echo "Starting triton server"
tritonserver --model-repository=$model_repo tritonserver --model-repository=$MODEL_REPO
fi fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server" echo "Testing triton server"
num_task=1 num_task=1
log_dir=./log_concurrent_tasks_${num_task} split_name=wenetspeech4tts
log_dir=./tests/client_grpc_concurrent_${num_task}_${split_name}
rm -r $log_dir rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
fi fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -74,37 +76,37 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
batch_size=1 batch_size=1
split_name=wenetspeech4tts split_name=wenetspeech4tts
backend_type=trt backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type} log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \ torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \ --batch-size $batch_size \
--enable-warmup \ --enable-warmup \
--split-name $split_name \ --split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \ --model-path $ckpt_file \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \ --vocab-file $vocab_file \
--vocoder-trt-engine-path $vocoder_trt_engine_path \ --vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \ --backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1 --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test" echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt if ! python3 -c "import f5_tts" &> /dev/null; then
batch_size=1 pip install -e ../../../../
fi
batch_size=1 # set attn_mask_enabled=True if batched
split_name=wenetspeech4tts split_name=wenetspeech4tts
backend_type=pytorch backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type} log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \ torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \ --batch-size $batch_size \
--split-name $split_name \ --split-name $split_name \
--enable-warmup \ --enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \ --model-path $ckpt_file \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \ --vocab-file $vocab_file \
--backend-type $backend_type \ --backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1 --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi fi
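The new `ckpt_file=$(ls ... | sort -V | tail -1)` line picks the highest-numbered checkpoint in the model directory; a rough Python equivalent, with the directory layout assumed from `run.sh`:

```python
import glob
import re

def latest_checkpoint(ckpt_dir: str, model: str) -> str:
    candidates = glob.glob(f"{ckpt_dir}/{model}/model_*.*")   # .pt or .safetensors
    if not candidates:
        raise FileNotFoundError(f"no checkpoints under {ckpt_dir}/{model}")
    def step(path: str) -> int:
        m = re.search(r"model_(\d+)", path)                   # e.g. model_1250000.safetensors
        return int(m.group(1)) if m else -1
    return max(candidates, key=step)

print(latest_checkpoint("../../../../ckpts", "F5TTS_v1_Base"))
```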

View File

@@ -172,11 +172,12 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
"--model_name", "--model_name",
type=str, type=str,
default="F5TTS_Base", default="F5TTS_v1_Base",
choices=[ choices=[
"F5TTS_v1_Base",
"F5TTS_Base", "F5TTS_Base",
], ],
) # TODO: support F5TTS_v1_Base )
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt") parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument( parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint" "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
@@ -184,7 +185,6 @@ def parse_arguments():
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT") parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers") parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module") parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size") parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size") parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size") parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -197,18 +197,29 @@ def parse_arguments():
return args return args
def convert_timm_dit(args, mapping, dtype="float32"): def convert_timm_dit(args, mapping, dtype="float32", use_ema=True):
weights = {} weights = {}
tik = time.time() tik = time.time()
torch_dtype = str_dtype_to_torch(dtype) torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt)) ckpt_path = args.timm_ckpt
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
model_params = load_file(ckpt_path)
else:
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
prefix = "ema_model.transformer." if use_ema else "transformer."
if any(k.startswith(prefix) for k in model_params.keys()):
model_params = { model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer") key[len(prefix) :] if key.startswith(prefix) else key: value
for key, value in model_params.items()
if key.startswith(prefix)
} }
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
@@ -230,7 +241,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
assert len(weights) == len(model_params) assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.' # new_prefix = "f5_transformer."
new_prefix = "" new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()} weights = {new_prefix + key: value for key, value in weights.items()}
import math import math
@@ -278,7 +289,7 @@ def save_config(args):
"num_hidden_layers": 22, "num_hidden_layers": 22,
"num_attention_heads": 16, "num_attention_heads": 16,
"dim_head": 64, "dim_head": 64,
"dropout": 0.1, "dropout": 0.0, # 0.1
"ff_mult": 2, "ff_mult": 2,
"mel_dim": 100, "mel_dim": 100,
"text_num_embeds": 256, "text_num_embeds": 256,
@@ -296,7 +307,7 @@ def save_config(args):
config["quantization"] = { config["quantization"] = {
"quant_algo": "FP8", "quant_algo": "FP8",
# TODO: add support for exclude modules. # TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*", # "exclude_modules": "*final_layer*",
} }
with open(os.path.join(args.output_dir, "config.json"), "w") as f: with open(os.path.join(args.output_dir, "config.json"), "w") as f:
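A condensed sketch of the loader logic added above (and mirrored in `get_text_embed_dict` earlier in this commit): accept `.safetensors` or `.pt`, optionally pick the EMA weights, and strip the transformer prefix so the keys match the DiT module:

```python
import torch

def load_transformer_state(ckpt_path: str, use_ema: bool = True) -> dict:
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        model_params = load_file(ckpt_path)                   # safetensors files hold a flat state dict
    else:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]

    prefix = "ema_model.transformer." if use_ema else "transformer."
    if any(k.startswith(prefix) for k in model_params):
        model_params = {k[len(prefix):]: v for k, v in model_params.items() if k.startswith(prefix)}
    return model_params
```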

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Manual installation of TensorRT, in case you are not using the NVIDIA NGC container:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
TRTEXEC="/usr/src/tensorrt/bin/trtexec" TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1 ONNX_PATH=$1
@@ -40,4 +42,3 @@ ${TRTEXEC} \
--maxShapes="mel:${MEL_MAX_SHAPE}" \ --maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \ --onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH} --saveEngine=${ENGINE_PATH}