runtime trtllm: support v1 and custom

SWivid
2025-10-21 22:02:25 +00:00
parent 8d3ec72159
commit c8bfc3aa3d
12 changed files with 237 additions and 281 deletions

View File

@@ -154,8 +154,8 @@ if __name__ == "__main__":
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
ref_text="Some call me nature, others call me mother nature.",
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,

View File

@@ -1,58 +1,64 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
### Setup
#### Option 1: Quick Start
```sh
# Directly launch the service using docker compose
MODEL=F5TTS_v1_Base docker compose up
```
### Build Image
Build the docker image from scratch.
#### Option 2: Build from scratch
```sh
# Build the docker image
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
# Create Docker Container
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
### Build TensorRT-LLM Engines and Launch Server
Inside the docker container, we follow the official TensorRT-LLM workflow to build the F5-TTS TensorRT-LLM engines; see the [whisper example](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper) for reference.
```sh
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
bash run.sh 0 4 F5TTS_v1_Base
```
> [!NOTE]
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`. Remember to use the matching model version (`F5TTS_v1_Base` for v1, `F5TTS_Base` for v0).
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`. Remember to use the matching model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
> If the checkpoint has a different structure, see `scripts/convert_checkpoint.py` and modify it as necessary.
> [!IMPORTANT]
> If the model was trained or finetuned with fp32, add the `--dtype float32` flag when converting the checkpoint in `run.sh` stage 1.
### HTTP Client
```sh
python3 client_http.py
```
### Benchmark using Client-Server Mode
### Benchmarking
#### Using Client-Server Mode
```sh
# bash run.sh 5 5 F5TTS_v1_Base
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark using Offline TRT-LLM Mode
#### Using Offline TRT-LLM Mode
```sh
# bash run.sh 7 7 F5TTS_v1_Base
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
```

View File

@@ -282,10 +282,10 @@ def main():
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
tllm_model_dir = args.tllm_model_dir
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
if args.backend_type == "trt":
tllm_model_dir = args.tllm_model_dir
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
model = F5TTS(
tllm_model_config,
debug_mode=False,
@@ -297,17 +297,18 @@ def main():
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
pe_attn_head=1,
text_mask_padding=False,
pretrained_config = tllm_model_config["pretrained_config"]
pt_model_config = dict(
dim=pretrained_config["hidden_size"],
depth=pretrained_config["num_hidden_layers"],
heads=pretrained_config["num_attention_heads"],
ff_mult=pretrained_config["ff_mult"],
text_dim=pretrained_config["text_dim"],
text_mask_padding=pretrained_config["text_mask_padding"],
conv_layers=pretrained_config["conv_layers"],
pe_attn_head=pretrained_config["pe_attn_head"],
)
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
model = load_model(DiT, pt_model_config, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
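For reference, a rough sketch of the `pretrained_config` block that `scripts/convert_checkpoint.py` writes into the engine directory's `config.json` and that the code above reads (values shown are the `F5TTS_v1_Base` defaults; the exact contents depend on your converted checkpoint, and the directory path here is hypothetical):
```python
import json
import os

# Hypothetical engine directory; in benchmark.py this comes from --tllm-model-dir.
tllm_model_dir = "./trtllm_engine_dir"

with open(os.path.join(tllm_model_dir, "config.json")) as f:
    tllm_model_config = json.load(f)

# For F5TTS_v1_Base, pretrained_config holds roughly these fields
# (written by save_config in scripts/convert_checkpoint.py, shown further below):
#   hidden_size=1024, num_hidden_layers=22, num_attention_heads=16, dim_head=64,
#   ff_mult=2, mel_dim=100, text_num_embeds=256, text_dim=512,
#   text_mask_padding=True, conv_layers=4, pe_attn_head=None
pretrained_config = tllm_model_config["pretrained_config"]
print(pretrained_config["hidden_size"], pretrained_config["pe_attn_head"])
```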

View File

@@ -220,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=24000):
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate
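As a quick illustration of the simplified call above: `scipy.signal.resample(x, num)` resamples a signal to `num` samples via FFT, so a sample-rate conversion only needs the new length.
```python
import numpy as np
from scipy.signal import resample

sample_rate, target_sample_rate = 16000, 24000
# one second of a 440 Hz tone at 16 kHz, as a stand-in for the loaded reference audio
waveform = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate)

# new length preserves duration: len(x) * (target_rate / source_rate)
waveform_24k = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
print(len(waveform), len(waveform_24k))  # 16000 24000
```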

View File

@@ -79,19 +79,19 @@ def get_args():
def prepare_request(
samples,
waveform,
reference_text,
target_text,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
waveform = waveform.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
@@ -109,16 +109,15 @@ def prepare_request(
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate
if __name__ == "__main__":
@@ -128,11 +127,11 @@ if __name__ == "__main__":
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
waveform, sr = load_audio(args.reference_audio)
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
waveform = np.array(waveform, dtype=np.float32)
data = prepare_request(waveform, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}

View File

@@ -33,9 +33,12 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
def __init__(
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
@@ -43,10 +46,18 @@ class TextEmbedding(nn.Module):
text = text + 1
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
if self.mask_padding:
text_mask = text == 0
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
text = self.text_blocks(text)
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
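A minimal plain-PyTorch sketch (not the TRT-LLM classes above) of the padding-mask behaviour being added: embeddings at filler positions are zeroed before and after each conv block, so padded positions cannot leak into neighbouring frames.
```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(10, 8)                         # toy text embedding; index 0 is the filler token
block = nn.Conv1d(8, 8, kernel_size=3, padding=1)   # stand-in for one ConvNeXtV2Block

text = torch.tensor([[3, 5, 7, 0, 0]])              # last two positions are padding
text_mask = text == 0                                # (b, n), True at padded positions

x = embed(text)                                      # (b, n, d)
x = x.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, x.size(-1)), 0.0)
x = block(x.transpose(1, 2)).transpose(1, 2)         # conv over the sequence dimension
x = x.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, x.size(-1)), 0.0)

print(x[0, -2:].abs().sum())                         # padded positions stay exactly zero
```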
@@ -203,17 +214,16 @@ class F5TTS(object):
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
text_num_embeds=vocab_size,
text_dim=config["pretrained_config"]["text_dim"],
mask_padding=config["pretrained_config"]["text_mask_padding"],
conv_layers=config["pretrained_config"]["conv_layers"],
precompute_max_pos=self.max_mel_len,
).to(self.device)
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
# self.target_audio_sample_rate = 24000
# self.target_rms = 0.1 # least rms when inference, normalize to if lower
# self.n_fft = 1024
# self.win_length = 1024
# self.hop_length = 256
self.n_mel_channels = 100
self.head_dim = 64
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
self.head_dim = config["pretrained_config"]["dim_head"]
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
@@ -236,9 +246,9 @@ class F5TTS(object):
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step)
tmp_dim = 256 # WAR: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
freq_embed_dim = 256 # Warning: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
half_dim = freq_embed_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
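A standalone sketch of the flow-step schedule and time embedding precomputed here, assuming `t` is a uniform grid over [0, 1] and the usual F5-TTS sinusoidal convention (scale by 1000, concatenate sin and cos halves) for the loop body that this hunk truncates:
```python
import math
import torch

nfe_steps = 16                     # assumed value for illustration
freq_embed_dim = 256               # hard-coded, as the warning above notes
half_dim = freq_embed_dim // 2

# cosine time schedule over [0, 1] and the per-step increments for the ODE solver
t = torch.linspace(0, 1, nfe_steps + 1, dtype=torch.float32)   # assumed uniform grid
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step)

# sinusoidal frequencies, pre-scaled by 1000
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)

time_expand = torch.zeros((1, nfe_steps, freq_embed_dim), dtype=torch.float32)
for i in range(nfe_steps):
    emb = time_step[i] * emb_factor                        # (half_dim,)
    time_expand[0, i] = torch.cat([emb.sin(), emb.cos()])  # (freq_embed_dim,)

print(time_expand.shape, delta_t.shape)  # torch.Size([1, 16, 256]) torch.Size([16])
```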

View File

@@ -113,7 +113,6 @@ class TritonPythonModel:
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 4096
self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():

View File

@@ -50,6 +50,7 @@ class F5TTS(PretrainedModel):
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
pe_attn_head=config.pe_attn_head,
)
for _ in range(self.depth)
]
@@ -79,13 +80,12 @@ class F5TTS(PretrainedModel):
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 4096
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mel_size = self.config.mel_dim
max_seq_len = 3000 # 4096
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
concat_feature_dim = mel_size + self.config.text_dim
freq_embed_dim = 256 # Warning: hard coding 256 here
head_dim = self.config.dim_head
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
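As a concrete check of the config-driven shapes above, here is how the ranges work out for `F5TTS_v1_Base` (mel_dim=100, text_dim=512, dim_head=64) with `max_batch_size=8`:
```python
# Worked example of the dynamic-shape ranges computed in prepare_inputs,
# using F5TTS_v1_Base config values.
max_batch_size = 8
mel_size, text_dim, head_dim = 100, 512, 64
max_seq_len = 3000       # reduced from 4096 in this commit
freq_embed_dim = 256     # still hard-coded, per the warning above

batch_size_range = [2, 2, max_batch_size]                                          # [2, 2, 8]
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]   # [200, 6000, 24000]
concat_feature_dim = mel_size + text_dim                                           # 612

print(batch_size_range, num_frames_range, concat_feature_dim, head_dim)
```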

View File

@@ -227,29 +227,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
full_dim = x.size(-1)
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
if pe_attn_head is None:
pe_attn_head = full_dim // head_dim
rotated_dim = head_dim * pe_attn_head
rotated_and_unrotated_list = []
if default_net().plugin_config.remove_input_padding: # for [N, D] input
new_t_shape = concat([shape(x, 0), head_dim]) # (-1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (-1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
rotated_and_unrotated_list.append(x_unrotated)
else: # for [B, N, D] input
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat(
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
) # (2, -1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
rotated_and_unrotated_list.append(x_unrotated)
out = concat(rotated_and_unrotated_list, dim=-1)
return out
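For intuition, a plain-PyTorch sketch (outside the TRT-LLM graph, so only an approximation of the functional-API code above) of what the rewritten function does for the packed [N, D] case: RoPE is applied head-by-head to the first `pe_attn_head` heads and the remaining channels pass through untouched.
```python
import torch

def rotate_every_two(x):
    # rotate channel pairs (x0, x1) -> (-x1, x0), analogous to rotate_every_two_3dim
    even, odd = x[..., 0::2], x[..., 1::2]
    return torch.stack((-odd, even), dim=-1).flatten(-2)

def apply_partial_rope(x, rope_cos, rope_sin, pe_attn_head=None):
    n, full_dim = x.shape
    head_dim = rope_cos.shape[-1]
    if pe_attn_head is None:
        pe_attn_head = full_dim // head_dim                 # rotate every head (v1 models)
    pieces = []
    for i in range(pe_attn_head):                           # per-head rotation
        xi = x[:, i * head_dim:(i + 1) * head_dim]
        pieces.append(xi * rope_cos + rotate_every_two(xi) * rope_sin)
    pieces.append(x[:, pe_attn_head * head_dim:])           # untouched tail channels
    return torch.cat(pieces, dim=-1)

x = torch.randn(10, 1024)                                   # e.g. 16 heads x 64 dims
rope_cos, rope_sin = torch.randn(10, 64), torch.randn(10, 64)
out_v0 = apply_partial_rope(x, rope_cos, rope_sin, pe_attn_head=1)     # F5TTS_Base / F5TTS_Small
out_v1 = apply_partial_rope(x, rope_cos, rope_sin, pe_attn_head=None)  # F5TTS_v1_* (all heads)
print(out_v0.shape, out_v1.shape)
```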
class AttnProcessor:
def __init__(self):
pass
def __init__(
self,
pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all
):
self.pe_attn_head = pe_attn_head
def __call__(
self,
@@ -265,8 +288,8 @@ class AttnProcessor:
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
# attention
inner_dim = key.shape[-1]
@@ -354,12 +377,12 @@ class AttnProcessor:
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,

View File

@@ -1,6 +1,6 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_v1_Base | F5TTS_Base
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
if [ -z "$model" ]; then
model=F5TTS_v1_Base
fi
@@ -26,7 +26,7 @@ vocab_file=$CKPT_DIR/$model/vocab.txt
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 scripts/convert_checkpoint.py \
--timm_ckpt $ckpt_file \
--pytorch_ckpt $ckpt_file \
--output_dir $TRTLLM_CKPT_DIR --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
@@ -58,7 +58,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
split_name=wenetspeech4tts
log_dir=./tests/client_grpc_concurrent_${num_task}_${split_name}
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
fi
@@ -68,7 +68,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
@@ -76,7 +76,7 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
@@ -95,10 +95,10 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
if ! python3 -c "import f5_tts" &> /dev/null; then
pip install -e ../../../../
fi
batch_size=1 # set attn_mask_enabled=True if batched
batch_size=1 # set attn_mask_enabled=True if batching in the actual use case
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
@@ -109,4 +109,4 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
--vocab-file $vocab_file \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi
fi

View File

@@ -23,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_v1_Base",
choices=[
"F5TTS_v1_Base",
"F5TTS_Base",
],
)
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -193,17 +37,86 @@ def parse_arguments():
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Custom",
choices=[
"F5TTS_v1_Base",
"F5TTS_Base",
"F5TTS_v1_Small",
"F5TTS_Small",
], # if set, overwrite the below hyperparams
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
parser.add_argument(
"--text_mask_padding",
type=lambda x: x.lower() == "true",
choices=[True, False],
default=True,
help="Whether apply padding mask for conv layers in text encoder",
)
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
args = parser.parse_args()
# overwrite with preset hyperparams if a known --model_name is given
if args.model_name == "F5TTS_v1_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
elif args.model_name == "F5TTS_v1_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
return args
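The four presets above only vary along two axes (model size, and v1 vs v0 masking/positional-embedding behaviour); a hypothetical table-driven equivalent makes that structure explicit:
```python
# Equivalent view of the presets: all share dim_head=64, ff_mult=2, text_dim=512, conv_layers=4.
SIZE = {
    "Base": dict(hidden_size=1024, depth=22, num_heads=16),
    "Small": dict(hidden_size=768, depth=18, num_heads=12),
}
VERSION = {
    "v1": dict(text_mask_padding=True, pe_attn_head=None),  # F5TTS_v1_*
    "v0": dict(text_mask_padding=False, pe_attn_head=1),    # F5TTS_*
}

def preset(model_name):
    version = "v1" if "_v1_" in model_name else "v0"
    size = "Small" if model_name.endswith("Small") else "Base"
    return {**SIZE[size], **VERSION[version],
            "dim_head": 64, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}

print(preset("F5TTS_v1_Base"))
print(preset("F5TTS_Small"))
```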
def convert_timm_dit(args, mapping, dtype="float32", use_ema=True):
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
ckpt_path = args.timm_ckpt
ckpt_path = args.pytorch_ckpt
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
@@ -221,16 +134,22 @@ def convert_timm_dit(args, mapping, dtype="float32", use_ema=True):
if key.startswith(prefix)
}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
pytorch_to_trtllm_name = {
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
}
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
def get_trtllm_name(pytorch_name):
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
if trtllm_name_if_matched != pytorch_name:
return trtllm_name_if_matched
return pytorch_name
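A quick self-contained check of the regex renaming above, which replaces the long explicit mapping that was removed: each rule rewrites one PyTorch parameter name into its TRT-LLM counterpart, and names with no matching rule pass through unchanged.
```python
import re

# Same rules as pytorch_to_trtllm_name above.
pytorch_to_trtllm_name = {
    r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
    r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
    r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
    r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
    r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
    r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
    r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
}

def get_trtllm_name(pytorch_name):
    for pattern, replacement in pytorch_to_trtllm_name.items():
        renamed = re.sub(pattern, replacement, pytorch_name)
        if renamed != pytorch_name:
            return renamed
    return pytorch_name

print(get_trtllm_name("transformer_blocks.21.ff.ff.0.0.weight"))   # -> transformer_blocks.21.ff.project_in.weight
print(get_trtllm_name("transformer_blocks.3.attn.to_out.0.bias"))  # -> transformer_blocks.3.attn.to_out.bias
print(get_trtllm_name("transformer_blocks.3.attn.to_q.weight"))    # unchanged
```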
weights = dict()
for name, param in model_params.items():
@@ -283,19 +202,19 @@ def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.0, # 0.1
"ff_mult": 2,
"hidden_size": args.hidden_size,
"num_hidden_layers": args.depth,
"num_attention_heads": args.num_heads,
"dim_head": args.dim_head,
"dropout": 0.0, # inference-only
"ff_mult": args.ff_mult,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"text_dim": args.text_dim,
"text_mask_padding": args.text_mask_padding,
"conv_layers": args.conv_layers,
"pe_attn_head": args.pe_attn_head,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
@@ -326,7 +245,7 @@ def covert_and_save(args, rank):
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
@@ -355,9 +274,9 @@ def main():
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
if args.pytorch_ckpt is None:
return
print("start execute")
print("Start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()

View File

@@ -30,7 +30,7 @@ MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MAX_INPUT_LENGTH=3000 # 4096
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"