5 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Yushen CHEN | f2a4f8581f | Update runtime README | 2025-10-22 08:37:32 +08:00 |
| SWivid | a17c5ae435 | pytorch imple.: fix batch 1 inference from last commit | 2025-10-22 00:31:56 +00:00 |
| SWivid | a0b8fb5df2 | runtime trtllm: minor fixes. pytorch: update text_embedding logic to correct v0 batching. | 2025-10-22 00:19:45 +00:00 |
| SWivid | c8bfc3aa3d | runtime trtllm: support v1 and custom | 2025-10-21 22:02:25 +00:00 |
| SWivid | 8d3ec72159 | runtime trtllm: clean-up v0 code, several fixes. | 2025-10-20 10:30:58 +00:00 |
16 changed files with 477 additions and 582 deletions

View File

@@ -154,8 +154,8 @@ if __name__ == "__main__":
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
ref_text="Some call me nature, others call me mother nature.",
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
@@ -20,7 +21,6 @@ from f5_tts.model.modules import (
ConvPositionEmbedding,
DiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -89,8 +89,7 @@ class TextEmbedding(nn.Module):
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
if self.mask_padding:
text_mask = text == 0
@@ -102,10 +101,7 @@ class TextEmbedding(nn.Module):
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
text = text + self.freqs_cis[:seq_len, :]
# convnextv2 blocks
if self.mask_padding:
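A quick sanity check of the simplification above: with the batch start offsets fixed at zero, the rows gathered via `get_pos_embed_indices` are just the first `seq_len` rows of the precomputed table, so indexing `self.freqs_cis[:seq_len, :]` directly is equivalent. A minimal sketch in plain PyTorch (the index construction below only mimics the helper; names are illustrative):

```python
import torch

seq_len, text_dim = 8, 16
freqs_cis = torch.randn(4096, text_dim)                       # precomputed sin/cos table

# old path (sketch of get_pos_embed_indices with zero start offsets)
batch_start = torch.zeros(2, dtype=torch.long)
pos_idx = batch_start.unsqueeze(-1) + torch.arange(seq_len)   # (b, seq_len)
old_pos_embed = freqs_cis[pos_idx]                            # (b, seq_len, text_dim)

# new path: plain slice, broadcast over the batch when added to text
new_pos_embed = freqs_cis[:seq_len, :]

assert torch.equal(old_pos_embed[0], new_pos_embed)
```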
@@ -241,18 +237,33 @@ class DiT(nn.Module):
cache: bool = True,
audio_mask: bool["b n"] | None = None, # noqa: F722
):
seq_len = x.shape[1]
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
else:
batch = x.shape[0]
seq_lens = audio_mask.sum(dim=1)
text_embed_list = []
for i in range(batch):
text_embed_i = self.text_embed(
text[i].unsqueeze(0),
seq_lens[i].item(),
drop_text=drop_text,
audio_mask=audio_mask,
)
text_embed_list.append(text_embed_i[0])
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
if cache:
if drop_text:
self.text_uncond = text_embed
else:
self.text_cond = text_embed
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
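The intent of the per-sample loop above, in isolation: when items in a batch have different lengths, embedding the text at the padded maximum length lets padded positions leak into the convolutional blocks, so each item is embedded at its own length (taken from `audio_mask`) and the results are re-padded afterwards. A standalone sketch of that pattern, with a hypothetical `embed_one` standing in for `self.text_embed`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def embed_one(tokens: torch.Tensor, seq_len: int) -> torch.Tensor:
    # stand-in for self.text_embed: returns (seq_len, dim) for a single item
    return torch.randn(seq_len, 4)

audio_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]], dtype=torch.bool)
text = torch.tensor([[5, 6, -1, -1], [7, 8, 9, -1]])

seq_lens = audio_mask.sum(dim=1)                                    # true length of each item
embeds = [embed_one(text[i], int(seq_lens[i])) for i in range(text.shape[0])]
batched = pad_sequence(embeds, batch_first=True, padding_value=0)   # re-pad to the batch max length
print(batched.shape)                                                # torch.Size([2, 4, 4])
```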

View File

@@ -252,10 +252,9 @@ class CFM(nn.Module):
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
if not exists(lens): # if lens not acquired by trainer from collate_fn
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
mask = lens_to_mask(lens, length=seq_len)
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)

View File

@@ -0,0 +1,3 @@
# runtime/triton_trtllm related
model.cache
model_repo/

View File

@@ -1,59 +1,68 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
### Setup
#### Option 1: Quick Start
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
# Directly launch the service using docker compose
MODEL=F5TTS_v1_Base docker compose up
```
### Build Image
Build the docker image from scratch.
#### Option 2: Build from scratch
```sh
# Build the docker image
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
# Create Docker Container
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
### Build TensorRT-LLM Engines and Launch Server
Inside the docker container, we follow the official TensorRT-LLM guide (as with the qwen and whisper examples) to build the TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
bash run.sh 0 4 F5TTS_v1_Base
```
> [!NOTE]
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
> Remember to use the matching model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
>
> If using a checkpoint with a different structure, see `scripts/convert_checkpoint.py` and modify it if necessary.
> [!IMPORTANT]
> If trained or finetuned with fp32, add the `--dtype float32` flag when converting the checkpoint in `run.sh` stage 1, as sketched below.
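For example, the checkpoint-conversion call in `run.sh` stage 1 would become something like the following (a sketch; paths reuse the variables already defined in `run.sh`):
```sh
python3 scripts/convert_checkpoint.py \
    --pytorch_ckpt $ckpt_file \
    --output_dir $TRTLLM_CKPT_DIR --model_name $model \
    --dtype float32
```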
### HTTP Client
```sh
python3 client_http.py
```
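For a request against a custom reference, the same client accepts explicit arguments, mirroring the example in `run.sh` stage 6 (the output path below is illustrative):
```sh
python3 client_http.py \
    --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
    --reference-text "Some call me nature, others call me mother nature." \
    --target-text "I don't really care what you call me." \
    --output-audio ./tests/client_http_demo.wav
```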
### Benchmark using Client-Server Mode
### Benchmarking
#### Using Client-Server Mode
```sh
# bash run.sh 5 5 F5TTS_v1_Base
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark using Offline TRT-LLM Mode
#### Using Offline TRT-LLM Mode
```sh
# bash run.sh 7 7 F5TTS_v1_Base
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
```
### Benchmark Results
@@ -66,4 +75,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pair
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
1. [Yuekai Zhang](https://github.com/yuekaizhang)
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
"""
import argparse
import importlib
import json
import os
import sys
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
torch.manual_seed(0)
@@ -111,22 +117,20 @@ def get_args():
return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
# pad along the last dimension
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
(
ids,
ref_rms_list,
ref_mel_list,
ref_mel_len_list,
estimated_reference_target_mel_len,
reference_target_texts_list,
) = (
[],
[],
[],
[],
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
ref_rms_list.append(ref_rms)
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
ref_audio = ref_audio.to("cuda")
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze()
ref_mel_len = ref_mel.shape[0]
assert ref_mel.shape[1] == 100
ref_mel_len = ref_mel.shape[-1]
assert ref_mel.shape[0] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
max_seq_len = max(estimated_reference_target_mel_len)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_rms_list": ref_rms_list,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
@@ -216,72 +212,6 @@ def init_distributed():
return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
@@ -316,29 +246,11 @@ def load_vocoder(
return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}")
logger.info(f"Loading vocoder engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
@@ -368,34 +280,35 @@ def main():
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
if args.backend_type == "trt":
model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
tllm_model_config,
debug_mode=False,
tllm_model_dir=tllm_model_dir,
model_path=args.model_path,
vocab_size=vocab_size,
)
elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
pe_attn_head=1,
text_mask_padding=False,
pretrained_config = tllm_model_config["pretrained_config"]
pt_model_config = dict(
dim=pretrained_config["hidden_size"],
depth=pretrained_config["num_hidden_layers"],
heads=pretrained_config["num_attention_heads"],
ff_mult=pretrained_config["ff_mult"],
text_dim=pretrained_config["text_dim"],
text_mask_padding=pretrained_config["text_mask_padding"],
conv_layers=pretrained_config["conv_layers"],
pe_attn_head=pretrained_config["pe_attn_head"],
)
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
model = load_model(DiT, pt_model_config, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
@@ -445,20 +358,23 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
text_pad_seq,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device)
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=16,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
@@ -478,13 +394,13 @@ def main():
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
ref_mels,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
@@ -494,20 +410,20 @@ def main():
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
target_rms = 0.1
target_sample_rate = 24000
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
@@ -519,13 +435,10 @@ def main():
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
if batch["ref_rms_list"][i] < target_rms:
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",

View File

@@ -30,15 +30,6 @@ python3 client_grpc.py \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
@@ -176,8 +167,7 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
help="triton model_repo module name to request",
)
parser.add_argument(
@@ -206,7 +196,7 @@ def get_args():
"--log-dir",
type=str,
required=False,
default="./tmp",
default="./tests/client_grpc",
help="log directory",
)
@@ -230,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=24000):
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate

View File

@@ -24,6 +24,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import os
import numpy as np
import requests
@@ -65,33 +66,32 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
default="tests/client_http.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
waveform,
reference_text,
target_text,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
waveform = waveform.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
@@ -109,16 +109,15 @@ def prepare_request(
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate
if __name__ == "__main__":
@@ -128,11 +127,11 @@ if __name__ == "__main__":
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
waveform, sr = load_audio(args.reference_audio)
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
waveform = np.array(waveform, dtype=np.float32)
data = prepare_request(waveform, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
@@ -140,4 +139,5 @@ if __name__ == "__main__":
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
sf.write(args.output_audio, audio, 24000, "PCM_16")

View File

@@ -12,6 +12,7 @@ import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
from torch.nn.utils.rnn import pad_sequence
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
@@ -32,26 +33,35 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
def __init__(
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
def forward(self, text, seq_len, drop_text=False):
text = text + 1
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
@@ -112,20 +122,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
def get_text_embed_dict(ckpt_path, use_ema=True):
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path)
else:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model_params = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
for key in model_params.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
return text_embed_dict
@@ -194,18 +217,16 @@ class F5TTS(object):
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
text_num_embeds=vocab_size,
text_dim=config["pretrained_config"]["text_dim"],
mask_padding=config["pretrained_config"]["text_mask_padding"],
conv_layers=config["pretrained_config"]["conv_layers"],
precompute_max_pos=self.max_mel_len,
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
self.head_dim = config["pretrained_config"]["dim_head"]
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
@@ -214,14 +235,23 @@ class F5TTS(object):
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
self.nfe_steps = 32
epss = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
freq_embed_dim = 256 # Warning: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
half_dim = freq_embed_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
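For reference, the schedule assembled above reduces to the same sway-style curve as the previous `t + (-1.0) * (cos(pi * 0.5 * t) - 1 + t)` expression (which simplifies to `1 - cos(pi * t / 2)`), now sub-sampled through the `epss` index lists. A standalone numeric check under the same defaults (`nfe_steps = 32` has no `epss` entry, so it falls back to the uniform index list):

```python
import torch

nfe_steps = 32
epss = {16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32]}  # subset shown
idx = epss.get(nfe_steps, list(range(nfe_steps + 1)))
t = torch.tensor(idx, dtype=torch.float32) / 32
time_step = 1 - torch.cos(torch.pi * t / 2)   # monotone schedule from 0 to 1
delta_t = torch.diff(time_step)               # per-step increments
print(time_step[0].item(), time_step[-1].item())  # 0.0, ~1.0
print(len(delta_t))                               # 32
```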
@@ -344,7 +374,7 @@ class F5TTS(object):
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
cond_pad_sequence: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
@@ -353,26 +383,43 @@ class F5TTS(object):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
max_seq_len = cond_pad_sequence.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
# get text_embed one by one to avoid misalignment
text_and_drop_embedding_list = []
for i in range(batch):
text_embedding_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=False,
)
text_embedding_drop_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=True,
)
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
# pad separately computed text_embed to form batch with max_seq_len
text_and_drop_embedding = pad_sequence(
text_and_drop_embedding_list,
batch_first=True,
padding_value=0,
)
text_embedding = text_and_drop_embedding[0::2]
text_embedding_drop = text_and_drop_embedding[1::2]
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
noise = torch.randn_like(cond_pad_sequence).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text = torch.cat(
(
cond_pad_sequence,
text_embedding,
),
dim=-1,
)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),

View File

@@ -28,7 +28,6 @@ import os
import jieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
@@ -99,7 +98,8 @@ def list_str_to_idx(
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
class TritonPythonModel:
@@ -107,13 +107,12 @@ class TritonPythonModel:
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.target_rms = 0.1 # minimum rms at inference; quieter reference audio is normalized up to this level
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
self.max_mel_len = 4096
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
@@ -181,7 +180,8 @@ class TritonPythonModel:
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
reference_rms_list,
) = [], [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
@@ -208,6 +208,7 @@ class TritonPythonModel:
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
reference_rms_list.append(ref_rms)
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
@@ -228,7 +229,7 @@ class TritonPythonModel:
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
@@ -237,15 +238,6 @@ class TritonPythonModel:
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
@@ -262,13 +254,12 @@ class TritonPythonModel:
responses = []
for i in range(batch):
ref_me_len = reference_mel_len[i]
ref_mel_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
if reference_rms_list[i] < self.target_rms:
audio = audio * reference_rms_list[i] / self.target_rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])

View File

@@ -50,6 +50,7 @@ class F5TTS(PretrainedModel):
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
pe_attn_head=config.pe_attn_head,
)
for _ in range(self.depth)
]
@@ -79,13 +80,12 @@ class F5TTS(PretrainedModel):
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mel_size = self.config.mel_dim
max_seq_len = 3000 # 4096
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
concat_feature_dim = mel_size + self.config.text_dim
freq_embed_dim = 256 # Warning: hard coding 256 here
head_dim = self.config.dim_head
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)

View File

@@ -227,29 +227,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
full_dim = x.size(-1)
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
if pe_attn_head is None:
pe_attn_head = full_dim // head_dim
rotated_dim = head_dim * pe_attn_head
rotated_and_unrotated_list = []
if default_net().plugin_config.remove_input_padding: # for [N, D] input
new_t_shape = concat([shape(x, 0), head_dim]) # (-1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (-1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
rotated_and_unrotated_list.append(x_unrotated)
else: # for [B, N, D] input
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat(
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
) # (2, -1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
rotated_and_unrotated_list.append(x_unrotated)
out = concat(rotated_and_unrotated_list, dim=-1)
return out
class AttnProcessor:
def __init__(self):
pass
def __init__(
self,
pe_attn_head: Optional[int] = None, # number of attention heads to apply rope to, None for all
):
self.pe_attn_head = pe_attn_head
def __call__(
self,
@@ -265,8 +288,8 @@ class AttnProcessor:
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
# attention
inner_dim = key.shape[-1]
@@ -354,12 +377,12 @@ class AttnProcessor:
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,

View File

@@ -1,24 +0,0 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -1,64 +1,66 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_Base"
model=F5TTS_Base
model=F5TTS_v1_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
CKPT_DIR=../../../../ckpts
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
MODEL_REPO=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
echo "Downloading F5-TTS from huggingface"
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
fi
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update
vocab_file=$CKPT_DIR/$model/vocab.txt
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python3 scripts/convert_checkpoint.py \
--pytorch_ckpt $ckpt_file \
--output_dir $TRTLLM_CKPT_DIR --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
rm -r $MODEL_REPO
cp -r ./model_repo_f5_tts $MODEL_REPO
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
tritonserver --model-repository=$MODEL_REPO
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
split_name=wenetspeech4tts
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -66,7 +68,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
@@ -74,37 +76,37 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt
batch_size=1
if ! python3 -c "import f5_tts" &> /dev/null; then
pip install -e ../../../../
fi
batch_size=1 # set attn_mask_enabled=True if batching in actual use case
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi

View File

@@ -23,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
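# Illustrative sketch only (assuming the 22-block F5TTS_Base layout): the per-layer feed-forward
# entries enumerated above collapse into the two regex rules used further below in
# pytorch_to_trtllm_name, applied with re.sub. The helper name _remap_ff_name is hypothetical.
import re

_ff_rules = {
    r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
    r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
}

def _remap_ff_name(name: str) -> str:
    # Apply the first matching rule; leave the name unchanged if nothing matches.
    for pattern, repl in _ff_rules.items():
        new_name = re.sub(pattern, repl, name)
        if new_name != name:
            return new_name
    return name

# Spot-check against one of the explicit entries above.
assert _remap_ff_name("transformer_blocks.21.ff.ff.0.0.weight") == "transformer_blocks.21.ff.project_in.weight"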
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -193,33 +37,119 @@ def parse_arguments():
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Custom",
choices=[
"F5TTS_v1_Base",
"F5TTS_Base",
"F5TTS_v1_Small",
"F5TTS_Small",
], # if set, overwrites the hyperparameters below
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
parser.add_argument(
"--text_mask_padding",
type=lambda x: x.lower() == "true",
choices=[True, False],
default=True,
help="Whether apply padding mask for conv layers in text encoder",
)
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
args = parser.parse_args()
# overwrite hyperparameters if a known --model_name is specified
if args.model_name == "F5TTS_v1_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
elif args.model_name == "F5TTS_v1_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
return args
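# Sketch of an equivalent, table-driven form of the per-model overrides above (values copied from
# the branches in parse_arguments; MODEL_PRESETS and apply_model_preset are hypothetical names):
MODEL_PRESETS = {
    "F5TTS_v1_Base": dict(hidden_size=1024, depth=22, num_heads=16, dim_head=64, ff_mult=2,
                          text_dim=512, text_mask_padding=True, conv_layers=4, pe_attn_head=None),
    "F5TTS_Base": dict(hidden_size=1024, depth=22, num_heads=16, dim_head=64, ff_mult=2,
                       text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1),
    "F5TTS_v1_Small": dict(hidden_size=768, depth=18, num_heads=12, dim_head=64, ff_mult=2,
                           text_dim=512, text_mask_padding=True, conv_layers=4, pe_attn_head=None),
    "F5TTS_Small": dict(hidden_size=768, depth=18, num_heads=12, dim_head=64, ff_mult=2,
                        text_dim=512, text_mask_padding=False, conv_layers=4, pe_attn_head=1),
}

def apply_model_preset(args):
    # Overwrite the parsed hyperparameters in place when a known preset name is given;
    # the default "F5TTS_Custom" falls through and keeps the command-line values.
    for key, value in MODEL_PRESETS.get(args.model_name, {}).items():
        setattr(args, key, value)
    return args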
def convert_timm_dit(args, mapping, dtype="float32"):
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
ckpt_path = args.pytorch_ckpt
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
model_params = load_file(ckpt_path)
else:
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
prefix = "ema_model.transformer." if use_ema else "transformer."
if any(k.startswith(prefix) for k in model_params.keys()):
model_params = {
key[len(prefix) :] if key.startswith(prefix) else key: value
for key, value in model_params.items()
if key.startswith(prefix)
}
pytorch_to_trtllm_name = {
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
def get_trtllm_name(pytorch_name):
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
if trtllm_name_if_matched != pytorch_name:
return trtllm_name_if_matched
return pytorch_name
weights = dict()
for name, param in model_params.items():
@@ -230,7 +160,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
# new_prefix = "f5_transformer."
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
@@ -272,19 +202,19 @@ def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"hidden_size": args.hidden_size,
"num_hidden_layers": args.depth,
"num_attention_heads": args.num_heads,
"dim_head": args.dim_head,
"dropout": 0.0, # inference-only
"ff_mult": args.ff_mult,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"text_dim": args.text_dim,
"text_mask_padding": args.text_mask_padding,
"conv_layers": args.conv_layers,
"pe_attn_head": args.pe_attn_head,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
@@ -296,7 +226,7 @@ def save_config(args):
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
# "exclude_modules": "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
@@ -315,7 +245,7 @@ def covert_and_save(args, rank):
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
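# Hypothetical post-conversion sanity check (illustrative helper; file names follow the defaults
# above, i.e. config.json and rank{N}.safetensors inside --output_dir):
def inspect_converted_checkpoint(output_dir, rank=0):
    import json
    import os

    from safetensors.torch import load_file

    with open(os.path.join(output_dir, "config.json")) as f:
        config = json.load(f)
    weights = load_file(os.path.join(output_dir, f"rank{rank}.safetensors"))
    print(f"architecture={config['architecture']}, dtype={config['dtype']}, tensors={len(weights)}")
    for name in list(weights)[:5]:
        print(name, tuple(weights[name].shape), weights[name].dtype)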
@@ -344,9 +274,9 @@ def main():
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
if args.pytorch_ckpt is None:
return
print("start execute")
print("Start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Manual installation of TensorRT, in case you are not using an NVIDIA NGC container:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
@@ -28,7 +30,7 @@ MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MAX_INPUT_LENGTH=3000 # 4096
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
@@ -40,4 +42,3 @@ ${TRTEXEC} \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}