mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-05 20:40:12 -08:00
Compare commits
5 Commits
65ada48a62
...
f2a4f8581f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2a4f8581f | ||
|
|
a17c5ae435 | ||
|
|
a0b8fb5df2 | ||
|
|
c8bfc3aa3d | ||
|
|
8d3ec72159 |
@@ -154,8 +154,8 @@ if __name__ == "__main__":
|
||||
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
ref_text="Some call me nature, others call me mother nature.",
|
||||
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
|
||||
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
@@ -20,7 +21,6 @@ from f5_tts.model.modules import (
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
@@ -89,8 +89,7 @@ class TextEmbedding(nn.Module):
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
|
||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
@@ -102,10 +101,7 @@ class TextEmbedding(nn.Module):
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
|
||||
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||
text_pos_embed = self.freqs_cis[pos_idx]
|
||||
text = text + text_pos_embed
|
||||
text = text + self.freqs_cis[:seq_len, :]
|
||||
|
||||
# convnextv2 blocks
|
||||
if self.mask_padding:
|
||||
@@ -241,18 +237,33 @@ class DiT(nn.Module):
|
||||
cache: bool = True,
|
||||
audio_mask: bool["b n"] | None = None, # noqa: F722
|
||||
):
|
||||
seq_len = x.shape[1]
|
||||
if self.text_uncond is None or self.text_cond is None or not cache:
|
||||
if audio_mask is None:
|
||||
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
|
||||
else:
|
||||
batch = x.shape[0]
|
||||
seq_lens = audio_mask.sum(dim=1)
|
||||
text_embed_list = []
|
||||
for i in range(batch):
|
||||
text_embed_i = self.text_embed(
|
||||
text[i].unsqueeze(0),
|
||||
seq_lens[i].item(),
|
||||
drop_text=drop_text,
|
||||
audio_mask=audio_mask,
|
||||
)
|
||||
text_embed_list.append(text_embed_i[0])
|
||||
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
||||
if cache:
|
||||
if drop_text:
|
||||
self.text_uncond = text_embed
|
||||
else:
|
||||
self.text_cond = text_embed
|
||||
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
|
||||
@@ -252,10 +252,9 @@ class CFM(nn.Module):
|
||||
assert text.shape[0] == batch
|
||||
|
||||
# lens and mask
|
||||
if not exists(lens):
|
||||
if not exists(lens): # if lens not acquired by trainer from collate_fn
|
||||
lens = torch.full((batch,), seq_len, device=device)
|
||||
|
||||
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
|
||||
mask = lens_to_mask(lens, length=seq_len)
|
||||
|
||||
# get a random span to mask out for training conditionally
|
||||
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
|
||||
|
||||
3
src/f5_tts/runtime/triton_trtllm/.gitignore
vendored
Normal file
3
src/f5_tts/runtime/triton_trtllm/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# runtime/triton_trtllm related
|
||||
model.cache
|
||||
model_repo/
|
||||
@@ -1,59 +1,68 @@
|
||||
## Triton Inference Serving Best Practice for F5-TTS
|
||||
|
||||
### Quick Start
|
||||
Directly launch the service using docker compose.
|
||||
### Setup
|
||||
#### Option 1: Quick Start
|
||||
```sh
|
||||
# TODO: support F5TTS_v1_Base
|
||||
MODEL=F5TTS_Base docker compose up
|
||||
# Directly launch the service using docker compose
|
||||
MODEL=F5TTS_v1_Base docker compose up
|
||||
```
|
||||
|
||||
### Build Image
|
||||
Build the docker image from scratch.
|
||||
#### Option 2: Build from scratch
|
||||
```sh
|
||||
# Build the docker image
|
||||
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
|
||||
```
|
||||
|
||||
### Create Docker Container
|
||||
```sh
|
||||
# Create Docker Container
|
||||
your_mount_dir=/mnt:/mnt
|
||||
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
|
||||
```
|
||||
|
||||
### Export Models to TensorRT-LLM and Launch Server
|
||||
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
|
||||
### Build TensorRT-LLM Engines and Launch Server
|
||||
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
|
||||
```sh
|
||||
bash run.sh 0 4 F5TTS_Base
|
||||
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
|
||||
bash run.sh 0 4 F5TTS_v1_Base
|
||||
```
|
||||
> [!NOTE]
|
||||
> If use custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
|
||||
> Remember to used matched model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
|
||||
>
|
||||
> If use checkpoint of different structure, see `scripts/convert_checkpoint.py`, and perform modification if necessary.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> If train or finetune with fp32, add `--dtype float32` flag when converting checkpoint in `run.sh` phase 1.
|
||||
|
||||
### HTTP Client
|
||||
```sh
|
||||
python3 client_http.py
|
||||
```
|
||||
|
||||
### Benchmark using Client-Server Mode
|
||||
### Benchmarking
|
||||
#### Using Client-Server Mode
|
||||
```sh
|
||||
# bash run.sh 5 5 F5TTS_v1_Base
|
||||
num_task=2
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
|
||||
```
|
||||
|
||||
### Benchmark using Offline TRT-LLM Mode
|
||||
#### Using Offline TRT-LLM Mode
|
||||
```sh
|
||||
# bash run.sh 7 7 F5TTS_v1_Base
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
```
|
||||
|
||||
### Benchmark Results
|
||||
@@ -66,4 +75,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pair
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
|
||||
|
||||
### Credits
|
||||
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
1. [Yuekai Zhang](https://github.com/yuekaizhang)
|
||||
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--model-path $CKPT_DIR/$model/model_1200000.pt \
|
||||
--vocab-file $CKPT_DIR/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import datasets
|
||||
import jieba
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from f5_tts_trtllm import F5TTS
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session, TensorInfo
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from vocos import Vocos
|
||||
|
||||
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
|
||||
|
||||
from f5_tts.eval.utils_eval import padded_mel_batch
|
||||
from f5_tts.model.modules import get_vocos_mel_spectrogram
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
|
||||
|
||||
|
||||
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@@ -111,22 +117,20 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
def padded_mel_batch(ref_mels, max_seq_len):
|
||||
padded_ref_mels = []
|
||||
for mel in ref_mels:
|
||||
# pad along the last dimension
|
||||
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
|
||||
padded_ref_mels.append(padded_ref_mel)
|
||||
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||
return padded_ref_mels
|
||||
|
||||
|
||||
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("data_collator")
|
||||
target_sample_rate = 24000
|
||||
target_rms = 0.1
|
||||
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
|
||||
(
|
||||
ids,
|
||||
ref_rms_list,
|
||||
ref_mel_list,
|
||||
ref_mel_len_list,
|
||||
estimated_reference_target_mel_len,
|
||||
reference_target_texts_list,
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||
ref_rms_list.append(ref_rms)
|
||||
if ref_rms < target_rms:
|
||||
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||
|
||||
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
|
||||
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
|
||||
ref_audio = ref_audio.to("cuda")
|
||||
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
ref_mel = ref_mel.squeeze()
|
||||
ref_mel_len = ref_mel.shape[0]
|
||||
assert ref_mel.shape[1] == 100
|
||||
ref_mel_len = ref_mel.shape[-1]
|
||||
assert ref_mel.shape[0] == 100
|
||||
|
||||
ref_mel_list.append(ref_mel)
|
||||
ref_mel_len_list.append(ref_mel_len)
|
||||
|
||||
estimated_reference_target_mel_len.append(
|
||||
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
|
||||
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
|
||||
)
|
||||
|
||||
max_seq_len = max(estimated_reference_target_mel_len)
|
||||
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
|
||||
ref_mel_batch = padded_mel_batch(ref_mel_list)
|
||||
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
|
||||
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
|
||||
|
||||
for i, item in enumerate(text_pad_sequence):
|
||||
text_pad_sequence[i] = F.pad(
|
||||
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
||||
)
|
||||
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
||||
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
|
||||
text_pad_sequence = F.pad(
|
||||
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
||||
)
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
return {
|
||||
"ids": ids,
|
||||
"ref_rms_list": ref_rms_list,
|
||||
"ref_mel_batch": ref_mel_batch,
|
||||
"ref_mel_len_batch": ref_mel_len_batch,
|
||||
"text_pad_sequence": text_pad_sequence,
|
||||
@@ -216,72 +212,6 @@ def init_distributed():
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def get_tokenizer(vocab_file_path: str):
|
||||
"""
|
||||
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
||||
- "char" for char-wise tokenizer, need .txt vocab_file
|
||||
- "byte" for utf-8 tokenizer
|
||||
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
||||
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
||||
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
||||
- if use "byte", set to 256 (unicode byte range)
|
||||
"""
|
||||
with open(vocab_file_path, "r", encoding="utf-8") as f:
|
||||
vocab_char_map = {}
|
||||
for i, char in enumerate(f):
|
||||
vocab_char_map[char[:-1]] = i
|
||||
vocab_size = len(vocab_char_map)
|
||||
return vocab_char_map, vocab_size
|
||||
|
||||
|
||||
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
|
||||
final_reference_target_texts_list = []
|
||||
custom_trans = str.maketrans(
|
||||
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
||||
) # add custom trans here, to address oov
|
||||
|
||||
def is_chinese(c):
|
||||
return "\u3100" <= c <= "\u9fff" # common chinese characters
|
||||
|
||||
for text in reference_target_texts_list:
|
||||
char_list = []
|
||||
text = text.translate(custom_trans)
|
||||
for seg in jieba.cut(text):
|
||||
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||
char_list.append(" ")
|
||||
char_list.extend(seg)
|
||||
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
|
||||
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
||||
for i, c in enumerate(seg):
|
||||
if is_chinese(c):
|
||||
char_list.append(" ")
|
||||
char_list.append(seg_[i])
|
||||
else: # if mixed characters, alphabets and symbols
|
||||
for c in seg:
|
||||
if ord(c) < 256:
|
||||
char_list.extend(c)
|
||||
elif is_chinese(c):
|
||||
char_list.append(" ")
|
||||
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
||||
else:
|
||||
char_list.append(c)
|
||||
final_reference_target_texts_list.append(char_list)
|
||||
|
||||
return final_reference_target_texts_list
|
||||
|
||||
|
||||
def list_str_to_idx(
|
||||
text: Union[List[str], List[List[str]]],
|
||||
vocab_char_map: Dict[str, int], # {char: idx}
|
||||
padding_value=-1,
|
||||
):
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return list_idx_tensors
|
||||
|
||||
|
||||
def load_vocoder(
|
||||
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
|
||||
):
|
||||
@@ -316,29 +246,11 @@ def load_vocoder(
|
||||
return vocoder
|
||||
|
||||
|
||||
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
|
||||
if vocoder == "vocos":
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=24000,
|
||||
n_fft=1024,
|
||||
win_length=1024,
|
||||
hop_length=256,
|
||||
n_mels=100,
|
||||
power=1,
|
||||
center=True,
|
||||
normalized=False,
|
||||
norm=None,
|
||||
).to(device)
|
||||
mel = mel_stft(waveform.to(device))
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel.transpose(1, 2)
|
||||
|
||||
|
||||
class VocosTensorRT:
|
||||
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
|
||||
logger.info(f"Loading vae engine from {engine_path}")
|
||||
logger.info(f"Loading vocoder engine from {engine_path}")
|
||||
self.engine_path = engine_path
|
||||
with open(engine_path, "rb") as f:
|
||||
engine_buffer = f.read()
|
||||
@@ -368,34 +280,35 @@ def main():
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
|
||||
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
|
||||
|
||||
tllm_model_dir = args.tllm_model_dir
|
||||
config_file = os.path.join(tllm_model_dir, "config.json")
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(tllm_model_dir, "config.json")) as f:
|
||||
tllm_model_config = json.load(f)
|
||||
if args.backend_type == "trt":
|
||||
model = F5TTS(
|
||||
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
|
||||
tllm_model_config,
|
||||
debug_mode=False,
|
||||
tllm_model_dir=tllm_model_dir,
|
||||
model_path=args.model_path,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
import sys
|
||||
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
|
||||
from f5_tts.infer.utils_infer import load_model
|
||||
from f5_tts.model import DiT
|
||||
|
||||
F5TTS_model_cfg = dict(
|
||||
dim=1024,
|
||||
depth=22,
|
||||
heads=16,
|
||||
ff_mult=2,
|
||||
text_dim=512,
|
||||
conv_layers=4,
|
||||
pe_attn_head=1,
|
||||
text_mask_padding=False,
|
||||
pretrained_config = tllm_model_config["pretrained_config"]
|
||||
pt_model_config = dict(
|
||||
dim=pretrained_config["hidden_size"],
|
||||
depth=pretrained_config["num_hidden_layers"],
|
||||
heads=pretrained_config["num_attention_heads"],
|
||||
ff_mult=pretrained_config["ff_mult"],
|
||||
text_dim=pretrained_config["text_dim"],
|
||||
text_mask_padding=pretrained_config["text_mask_padding"],
|
||||
conv_layers=pretrained_config["conv_layers"],
|
||||
pe_attn_head=pretrained_config["pe_attn_head"],
|
||||
)
|
||||
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
|
||||
model = load_model(DiT, pt_model_config, args.model_path)
|
||||
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
|
||||
@@ -445,20 +358,23 @@ def main():
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
|
||||
if args.backend_type == "trt":
|
||||
_ = model.sample(
|
||||
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
|
||||
text_pad_seq,
|
||||
cond_pad_seq,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
text_pad_seq -= 1
|
||||
text_pad_seq[text_pad_seq == -2] = -1
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
steps=16,
|
||||
steps=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
@@ -478,13 +394,13 @@ def main():
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
|
||||
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
if args.backend_type == "trt":
|
||||
generated, cost_time = model.sample(
|
||||
text_pad_seq,
|
||||
ref_mels,
|
||||
cond_pad_seq,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
@@ -494,20 +410,20 @@ def main():
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
start_time = time.time()
|
||||
text_pad_seq -= 1
|
||||
text_pad_seq[text_pad_seq == -2] = -1
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=16,
|
||||
steps=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
cost_time = time.time() - start_time
|
||||
decoding_time += cost_time
|
||||
vocoder_start_time = time.time()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24000
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
@@ -519,13 +435,10 @@ def main():
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24_000
|
||||
# if ref_rms_list[i] < target_rms:
|
||||
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
|
||||
if rms < target_rms:
|
||||
generated_wave = generated_wave * target_rms / rms
|
||||
|
||||
if batch["ref_rms_list"][i] < target_rms:
|
||||
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
|
||||
|
||||
utt = batch["ids"][i]
|
||||
torchaudio.save(
|
||||
f"{args.output_dir}/{utt}.wav",
|
||||
|
||||
@@ -30,15 +30,6 @@ python3 client_grpc.py \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name test_zh \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
|
||||
# For offline Spark-TTS-0.5B
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name spark_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name wenetspeech4tts \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -176,8 +167,7 @@ def get_args():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -206,7 +196,7 @@ def get_args():
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=False,
|
||||
default="./tmp",
|
||||
default="./tests/client_grpc",
|
||||
help="log directory",
|
||||
)
|
||||
|
||||
@@ -230,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=24000):
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||
waveform = resample(waveform, num_samples)
|
||||
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -65,33 +66,32 @@ def get_args():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-audio",
|
||||
type=str,
|
||||
default="output.wav",
|
||||
default="tests/client_http.wav",
|
||||
help="Path to save the output audio",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prepare_request(
|
||||
samples,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=24000,
|
||||
audio_save_dir: str = "./",
|
||||
):
|
||||
assert len(samples.shape) == 1, "samples should be 1D"
|
||||
lengths = np.array([[len(samples)]], dtype=np.int32)
|
||||
samples = samples.reshape(1, -1).astype(np.float32)
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
waveform = waveform.reshape(1, -1).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"inputs": [
|
||||
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
|
||||
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
|
||||
{
|
||||
"name": "reference_wav_len",
|
||||
"shape": lengths.shape,
|
||||
@@ -109,16 +109,15 @@ def prepare_request(
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
samples = wav_path["array"]
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
else:
|
||||
samples, sample_rate = sf.read(wav_path)
|
||||
waveform, sample_rate = sf.read(wav_path)
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
|
||||
samples = resample(samples, num_samples)
|
||||
return samples, target_sample_rate
|
||||
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -128,11 +127,11 @@ if __name__ == "__main__":
|
||||
server_url = f"http://{server_url}"
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
samples, sr = load_audio(args.reference_audio)
|
||||
waveform, sr = load_audio(args.reference_audio)
|
||||
assert sr == 24000, "sample rate hardcoded in server"
|
||||
|
||||
samples = np.array(samples, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
waveform = np.array(waveform, dtype=np.float32)
|
||||
data = prepare_request(waveform, args.reference_text, args.target_text)
|
||||
|
||||
rsp = requests.post(
|
||||
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
|
||||
@@ -140,4 +139,5 @@ if __name__ == "__main__":
|
||||
result = rsp.json()
|
||||
audio = result["outputs"][0]["data"]
|
||||
audio = np.array(audio, dtype=np.float32)
|
||||
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
|
||||
sf.write(args.output_audio, audio, 24000, "PCM_16")
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch.nn.functional as F
|
||||
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
@@ -32,26 +33,35 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
|
||||
def __init__(
|
||||
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
|
||||
):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
self.mask_padding = mask_padding
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
||||
|
||||
def forward(self, text):
|
||||
# only keep tensors with value not -1
|
||||
text_mask = text != -1
|
||||
text_pad_cut_off_index = text_mask.sum(dim=1).max()
|
||||
def forward(self, text, seq_len, drop_text=False):
|
||||
text = text + 1
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
text = text + self.freqs_cis[:seq_len, :]
|
||||
if self.mask_padding:
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
text = text[:, :text_pad_cut_off_index]
|
||||
text = self.text_embed(text)
|
||||
text = text + self.freqs_cis[: text.shape[1], :]
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
# padding text to the original length
|
||||
# text shape: B,seq_len,C
|
||||
# pad at the second dimension
|
||||
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
|
||||
return text
|
||||
|
||||
|
||||
@@ -112,20 +122,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
|
||||
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_path, use_ema=True):
|
||||
checkpoint = torch.load(ckpt_path, weights_only=True)
|
||||
def get_text_embed_dict(ckpt_path, use_ema=True):
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
checkpoint = load_file(ckpt_path)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
|
||||
if use_ema:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"ema_model_state_dict": checkpoint}
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
dict_state = checkpoint["model_state_dict"]
|
||||
else:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"model_state_dict": checkpoint}
|
||||
model_params = checkpoint["model_state_dict"]
|
||||
|
||||
text_embed_dict = {}
|
||||
for key in dict_state.keys():
|
||||
for key in model_params.keys():
|
||||
# transformer.text_embed.text_embed.weight -> text_embed.weight
|
||||
if "text_embed" in key:
|
||||
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
|
||||
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
|
||||
return text_embed_dict
|
||||
|
||||
|
||||
@@ -194,18 +217,16 @@ class F5TTS(object):
|
||||
|
||||
self.max_mel_len = 4096
|
||||
self.text_embedding = TextEmbedding(
|
||||
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
|
||||
text_num_embeds=vocab_size,
|
||||
text_dim=config["pretrained_config"]["text_dim"],
|
||||
mask_padding=config["pretrained_config"]["text_mask_padding"],
|
||||
conv_layers=config["pretrained_config"]["conv_layers"],
|
||||
precompute_max_pos=self.max_mel_len,
|
||||
).to(self.device)
|
||||
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
|
||||
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
|
||||
|
||||
self.target_audio_sample_rate = 24000
|
||||
self.target_rms = 0.15 # target rms for audio
|
||||
self.n_fft = 1024
|
||||
self.win_length = 1024
|
||||
self.hop_length = 256
|
||||
self.n_mel_channels = 100
|
||||
# self.max_mel_len = 3000
|
||||
self.head_dim = 64
|
||||
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
|
||||
self.head_dim = config["pretrained_config"]["dim_head"]
|
||||
self.base_rescale_factor = 1.0
|
||||
self.interpolation_factor = 1.0
|
||||
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
|
||||
@@ -214,14 +235,23 @@ class F5TTS(object):
|
||||
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
|
||||
self.rope_cos = self.freqs.cos().half()
|
||||
self.rope_sin = self.freqs.sin().half()
|
||||
self.nfe_steps = 16
|
||||
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
|
||||
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
|
||||
|
||||
self.nfe_steps = 32
|
||||
epss = {
|
||||
5: [0, 2, 4, 8, 16, 32],
|
||||
6: [0, 2, 4, 6, 8, 16, 32],
|
||||
7: [0, 2, 4, 6, 8, 16, 24, 32],
|
||||
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
|
||||
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
}
|
||||
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
|
||||
time_step = 1 - torch.cos(torch.pi * t / 2)
|
||||
delta_t = torch.diff(time_step)
|
||||
# WAR: hard coding 256 here
|
||||
tmp_dim = 256
|
||||
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
|
||||
half_dim = tmp_dim // 2
|
||||
|
||||
freq_embed_dim = 256 # Warning: hard coding 256 here
|
||||
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
|
||||
half_dim = freq_embed_dim // 2
|
||||
emb_factor = math.log(10000) / (half_dim - 1)
|
||||
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
|
||||
for i in range(self.nfe_steps):
|
||||
@@ -344,7 +374,7 @@ class F5TTS(object):
|
||||
def sample(
|
||||
self,
|
||||
text_pad_sequence: torch.Tensor,
|
||||
ref_mel_batch: torch.Tensor,
|
||||
cond_pad_sequence: torch.Tensor,
|
||||
ref_mel_len_batch: torch.Tensor,
|
||||
estimated_reference_target_mel_len: List[int],
|
||||
remove_input_padding: bool = False,
|
||||
@@ -353,26 +383,43 @@ class F5TTS(object):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("text embedding")
|
||||
batch = text_pad_sequence.shape[0]
|
||||
max_seq_len = ref_mel_batch.shape[1]
|
||||
max_seq_len = cond_pad_sequence.shape[1]
|
||||
|
||||
text_pad_sequence_drop = torch.cat(
|
||||
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
|
||||
# get text_embed one by one to avoid misalignment
|
||||
text_and_drop_embedding_list = []
|
||||
for i in range(batch):
|
||||
text_embedding_i = self.text_embedding(
|
||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||
estimated_reference_target_mel_len[i],
|
||||
drop_text=False,
|
||||
)
|
||||
text_embedding_drop_i = self.text_embedding(
|
||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||
estimated_reference_target_mel_len[i],
|
||||
drop_text=True,
|
||||
)
|
||||
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
|
||||
|
||||
# pad separately computed text_embed to form batch with max_seq_len
|
||||
text_and_drop_embedding = pad_sequence(
|
||||
text_and_drop_embedding_list,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
text_embedding = text_and_drop_embedding[0::2]
|
||||
text_embedding_drop = text_and_drop_embedding[1::2]
|
||||
|
||||
text_embedding_drop_list = []
|
||||
for i in range(batch + 1):
|
||||
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
|
||||
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
|
||||
|
||||
text_embedding = text_embedding_drop_condition[:-1]
|
||||
# text_embedding_drop B,T,C batch should be the same
|
||||
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
|
||||
|
||||
noise = torch.randn_like(ref_mel_batch).to(self.device)
|
||||
noise = torch.randn_like(cond_pad_sequence).to(self.device)
|
||||
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
||||
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
||||
|
||||
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
|
||||
cat_mel_text = torch.cat(
|
||||
(
|
||||
cond_pad_sequence,
|
||||
text_embedding,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
cat_mel_text_drop = torch.cat(
|
||||
(
|
||||
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
|
||||
|
||||
@@ -28,7 +28,6 @@ import os
|
||||
|
||||
import jieba
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from f5_tts_trtllm import F5TTS
|
||||
@@ -99,7 +98,8 @@ def list_str_to_idx(
|
||||
padding_value=-1,
|
||||
): # noqa: F722
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
return list_idx_tensors
|
||||
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
@@ -107,13 +107,12 @@ class TritonPythonModel:
|
||||
self.use_perf = True
|
||||
self.device = torch.device("cuda")
|
||||
self.target_audio_sample_rate = 24000
|
||||
self.target_rms = 0.15 # target rms for audio
|
||||
self.target_rms = 0.1 # least rms when inference, normalize to if lower
|
||||
self.n_fft = 1024
|
||||
self.win_length = 1024
|
||||
self.hop_length = 256
|
||||
self.n_mel_channels = 100
|
||||
self.max_mel_len = 3000
|
||||
self.head_dim = 64
|
||||
self.max_mel_len = 4096
|
||||
|
||||
parameters = json.loads(args["model_config"])["parameters"]
|
||||
for key, value in parameters.items():
|
||||
@@ -181,7 +180,8 @@ class TritonPythonModel:
|
||||
reference_target_texts_list,
|
||||
estimated_reference_target_mel_len,
|
||||
reference_mel_len,
|
||||
) = [], [], [], [], []
|
||||
reference_rms_list,
|
||||
) = [], [], [], [], [], []
|
||||
mel_features_list = []
|
||||
if self.use_perf:
|
||||
torch.cuda.nvtx.range_push("preprocess")
|
||||
@@ -208,6 +208,7 @@ class TritonPythonModel:
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
|
||||
if ref_rms < self.target_rms:
|
||||
wav = wav * self.target_rms / ref_rms
|
||||
reference_rms_list.append(ref_rms)
|
||||
if self.reference_sample_rate != self.target_audio_sample_rate:
|
||||
wav = self.resampler(wav)
|
||||
wav = wav.to(self.device)
|
||||
@@ -228,7 +229,7 @@ class TritonPythonModel:
|
||||
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
||||
|
||||
batch = len(requests)
|
||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
|
||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
|
||||
for i, mel in enumerate(mel_features_list):
|
||||
mel_features[i, : mel.shape[1], :] = mel
|
||||
|
||||
@@ -237,15 +238,6 @@ class TritonPythonModel:
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
|
||||
|
||||
for i, item in enumerate(text_pad_sequence):
|
||||
text_pad_sequence[i] = F.pad(
|
||||
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
||||
)
|
||||
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
||||
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
|
||||
text_pad_sequence = F.pad(
|
||||
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
||||
)
|
||||
if self.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
@@ -262,13 +254,12 @@ class TritonPythonModel:
|
||||
|
||||
responses = []
|
||||
for i in range(batch):
|
||||
ref_me_len = reference_mel_len[i]
|
||||
ref_mel_len = reference_mel_len[i]
|
||||
estimated_mel_len = estimated_reference_target_mel_len[i]
|
||||
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||
audio = self.forward_vocoder(denoised_one_item)
|
||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||
if rms < self.target_rms:
|
||||
audio = audio * self.target_rms / rms
|
||||
if reference_rms_list[i] < self.target_rms:
|
||||
audio = audio * reference_rms_list[i] / self.target_rms
|
||||
|
||||
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
|
||||
|
||||
@@ -50,6 +50,7 @@ class F5TTS(PretrainedModel):
|
||||
dim_head=config.dim_head,
|
||||
ff_mult=config.ff_mult,
|
||||
dropout=config.dropout,
|
||||
pe_attn_head=config.pe_attn_head,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
@@ -79,13 +80,12 @@ class F5TTS(PretrainedModel):
|
||||
def prepare_inputs(self, **kwargs):
|
||||
max_batch_size = kwargs["max_batch_size"]
|
||||
batch_size_range = [2, 2, max_batch_size]
|
||||
mel_size = 100
|
||||
max_seq_len = 3000
|
||||
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
|
||||
hidden_size = 512
|
||||
concat_feature_dim = mel_size + hidden_size
|
||||
freq_embed_dim = 256
|
||||
head_dim = 64
|
||||
mel_size = self.config.mel_dim
|
||||
max_seq_len = 3000 # 4096
|
||||
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
|
||||
concat_feature_dim = mel_size + self.config.text_dim
|
||||
freq_embed_dim = 256 # Warning: hard coding 256 here
|
||||
head_dim = self.config.dim_head
|
||||
mapping = self.config.mapping
|
||||
if mapping.tp_size > 1:
|
||||
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
|
||||
|
||||
@@ -227,29 +227,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
|
||||
return out
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
rot_dim = shape(rope_cos, -1) # 64
|
||||
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
|
||||
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
|
||||
end_dim = shape(x, -1) - shape(rope_cos, -1)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
else:
|
||||
rot_dim = shape(rope_cos, 2) # 64
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
|
||||
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
|
||||
end_dim = shape(x, 2) - shape(rope_cos, 2)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
|
||||
full_dim = x.size(-1)
|
||||
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
|
||||
if pe_attn_head is None:
|
||||
pe_attn_head = full_dim // head_dim
|
||||
rotated_dim = head_dim * pe_attn_head
|
||||
|
||||
rotated_and_unrotated_list = []
|
||||
|
||||
if default_net().plugin_config.remove_input_padding: # for [N, D] input
|
||||
new_t_shape = concat([shape(x, 0), head_dim]) # (2, -1, 64)
|
||||
|
||||
for i in range(pe_attn_head):
|
||||
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
|
||||
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
|
||||
rotated_and_unrotated_list.append(x_rotated_i)
|
||||
|
||||
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (2, -1, 1024 - 64 * pe_attn_head)
|
||||
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
|
||||
rotated_and_unrotated_list.append(x_unrotated)
|
||||
|
||||
else: # for [B, N, D] input
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
|
||||
|
||||
for i in range(pe_attn_head):
|
||||
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
|
||||
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
|
||||
rotated_and_unrotated_list.append(x_rotated_i)
|
||||
|
||||
new_t_unrotated_shape = concat(
|
||||
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
|
||||
) # (2, -1, 1024 - 64 * pe_attn_head)
|
||||
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
rotated_and_unrotated_list.append(x_unrotated)
|
||||
|
||||
out = concat(rotated_and_unrotated_list, dim=-1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all
|
||||
):
|
||||
self.pe_attn_head = pe_attn_head
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -265,8 +288,8 @@ class AttnProcessor:
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
# k,v,q all (2,1226,1024)
|
||||
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
|
||||
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
|
||||
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
|
||||
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
@@ -354,12 +377,12 @@ class AttnProcessor:
|
||||
|
||||
# DiT Block
|
||||
class DiTBlock(Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
accelerate>=0.33.0
|
||||
bitsandbytes>0.37.0
|
||||
cached_path
|
||||
click
|
||||
datasets
|
||||
ema_pytorch>=0.5.2
|
||||
gradio>=3.45.2
|
||||
hydra-core>=1.3.0
|
||||
jieba
|
||||
librosa
|
||||
matplotlib
|
||||
numpy<=1.26.4
|
||||
pydub
|
||||
pypinyin
|
||||
safetensors
|
||||
soundfile
|
||||
tomli
|
||||
torch>=2.0.0
|
||||
# torchaudio>=2.0.0
|
||||
torchdiffeq
|
||||
tqdm>=4.65.0
|
||||
transformers
|
||||
x_transformers>=1.31.14
|
||||
packaging>=24.2
|
||||
@@ -1,64 +1,66 @@
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
model=$3 # F5TTS_Base
|
||||
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
|
||||
if [ -z "$model" ]; then
|
||||
echo "Model is none, using default model F5TTS_Base"
|
||||
model=F5TTS_Base
|
||||
model=F5TTS_v1_Base
|
||||
fi
|
||||
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
|
||||
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
|
||||
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
|
||||
CKPT_DIR=../../../../ckpts
|
||||
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
|
||||
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
|
||||
|
||||
vocoder_trt_engine_path=vocos_vocoder.plan
|
||||
model_repo=./model_repo
|
||||
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
|
||||
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
|
||||
MODEL_REPO=./model_repo
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading f5 tts from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
|
||||
|
||||
echo "Downloading F5-TTS from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
|
||||
fi
|
||||
|
||||
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update
|
||||
vocab_file=$CKPT_DIR/$model/vocab.txt
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Converting checkpoint"
|
||||
python3 ./scripts/convert_checkpoint.py \
|
||||
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
|
||||
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
|
||||
python3 scripts/convert_checkpoint.py \
|
||||
--pytorch_ckpt $ckpt_file \
|
||||
--output_dir $TRTLLM_CKPT_DIR --model_name $model
|
||||
python_package_path=/usr/local/lib/python3.12/dist-packages
|
||||
cp -r patch/* $python_package_path/tensorrt_llm/models
|
||||
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
|
||||
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
|
||||
--max_batch_size 8 \
|
||||
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
|
||||
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
echo "Exporting vocos vocoder"
|
||||
onnx_vocoder_path=vocos_vocoder.onnx
|
||||
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
|
||||
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
|
||||
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
|
||||
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
echo "Building triton server"
|
||||
rm -r $model_repo
|
||||
cp -r ./model_repo_f5_tts $model_repo
|
||||
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
|
||||
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
|
||||
rm -r $MODEL_REPO
|
||||
cp -r ./model_repo_f5_tts $MODEL_REPO
|
||||
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
|
||||
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
echo "Starting triton server"
|
||||
tritonserver --model-repository=$model_repo
|
||||
tritonserver --model-repository=$MODEL_REPO
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Testing triton server"
|
||||
num_task=1
|
||||
log_dir=./log_concurrent_tasks_${num_task}
|
||||
split_name=wenetspeech4tts
|
||||
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
|
||||
rm -r $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
@@ -66,7 +68,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
audio=../../infer/examples/basic/basic_ref_en.wav
|
||||
reference_text="Some call me nature, others call me mother nature."
|
||||
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
@@ -74,37 +76,37 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Native Pytorch: offline decoding benchmark test"
|
||||
pip install -r requirements-pytorch.txt
|
||||
batch_size=1
|
||||
if ! python3 -c "import f5_tts" &> /dev/null; then
|
||||
pip install -e ../../../../
|
||||
fi
|
||||
batch_size=1 # set attn_mask_enabled=True if batching in actual use case
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=pytorch
|
||||
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--split-name $split_name \
|
||||
--enable-warmup \
|
||||
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
|
||||
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
fi
|
||||
@@ -23,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
|
||||
return split_v.contiguous()
|
||||
|
||||
|
||||
FACEBOOK_DIT_NAME_MAPPING = {
|
||||
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
|
||||
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
|
||||
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
|
||||
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
|
||||
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
|
||||
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
|
||||
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
|
||||
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
|
||||
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
|
||||
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
|
||||
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
|
||||
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
|
||||
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
|
||||
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
|
||||
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
|
||||
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
|
||||
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
|
||||
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
|
||||
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
|
||||
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
|
||||
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
|
||||
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
|
||||
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
|
||||
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
|
||||
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
|
||||
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
|
||||
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
|
||||
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
|
||||
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
|
||||
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
|
||||
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
|
||||
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
|
||||
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
|
||||
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
|
||||
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
|
||||
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
|
||||
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
|
||||
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
|
||||
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
|
||||
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
|
||||
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
|
||||
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
|
||||
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
|
||||
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
|
||||
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
|
||||
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
|
||||
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
|
||||
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
|
||||
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
|
||||
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
|
||||
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
|
||||
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
|
||||
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
|
||||
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
|
||||
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
|
||||
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
|
||||
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
|
||||
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
|
||||
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
|
||||
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
|
||||
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
|
||||
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
|
||||
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
|
||||
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
|
||||
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
|
||||
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
|
||||
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
|
||||
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
|
||||
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
|
||||
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
|
||||
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
|
||||
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
|
||||
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
|
||||
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
|
||||
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
|
||||
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
|
||||
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
|
||||
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
|
||||
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
|
||||
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
|
||||
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
|
||||
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
|
||||
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
|
||||
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
|
||||
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
|
||||
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
|
||||
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
|
||||
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
|
||||
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
|
||||
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
|
||||
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
|
||||
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
|
||||
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
|
||||
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
|
||||
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
|
||||
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
|
||||
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
|
||||
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
|
||||
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
|
||||
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
|
||||
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
|
||||
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
|
||||
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
|
||||
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
|
||||
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
|
||||
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
|
||||
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
|
||||
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
|
||||
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
|
||||
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
|
||||
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
|
||||
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
|
||||
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
|
||||
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
|
||||
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
|
||||
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
|
||||
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
|
||||
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
|
||||
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
|
||||
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
|
||||
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
|
||||
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
|
||||
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
|
||||
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
|
||||
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
|
||||
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
|
||||
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
|
||||
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
|
||||
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
|
||||
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
|
||||
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
|
||||
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
|
||||
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
|
||||
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
|
||||
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
|
||||
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
|
||||
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
|
||||
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
|
||||
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
|
||||
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
|
||||
}
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="F5TTS_Base",
|
||||
choices=[
|
||||
"F5TTS_Base",
|
||||
],
|
||||
) # TODO: support F5TTS_v1_Base
|
||||
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
|
||||
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
|
||||
)
|
||||
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
||||
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
||||
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
||||
parser.add_argument("--cfg_scale", type=float, default=4.0)
|
||||
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
|
||||
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
|
||||
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
|
||||
@@ -193,33 +37,119 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="F5TTS_Custom",
|
||||
choices=[
|
||||
"F5TTS_v1_Base",
|
||||
"F5TTS_Base",
|
||||
"F5TTS_v1_Small",
|
||||
"F5TTS_Small",
|
||||
], # if set, overwrite the below hyperparams
|
||||
)
|
||||
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
||||
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
||||
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
||||
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
|
||||
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
|
||||
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
|
||||
parser.add_argument(
|
||||
"--text_mask_padding",
|
||||
type=lambda x: x.lower() == "true",
|
||||
choices=[True, False],
|
||||
default=True,
|
||||
help="Whether apply padding mask for conv layers in text encoder",
|
||||
)
|
||||
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
|
||||
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
|
||||
args = parser.parse_args()
|
||||
|
||||
# overwrite if --model_name ordered
|
||||
if args.model_name == "F5TTS_v1_Base":
|
||||
args.hidden_size = 1024
|
||||
args.depth = 22
|
||||
args.num_heads = 16
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = True
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = None
|
||||
elif args.model_name == "F5TTS_Base":
|
||||
args.hidden_size = 1024
|
||||
args.depth = 22
|
||||
args.num_heads = 16
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = False
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = 1
|
||||
elif args.model_name == "F5TTS_v1_Small":
|
||||
args.hidden_size = 768
|
||||
args.depth = 18
|
||||
args.num_heads = 12
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = True
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = None
|
||||
elif args.model_name == "F5TTS_Small":
|
||||
args.hidden_size = 768
|
||||
args.depth = 18
|
||||
args.num_heads = 12
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = False
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = 1
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def convert_timm_dit(args, mapping, dtype="float32"):
|
||||
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
|
||||
weights = {}
|
||||
tik = time.time()
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
tensor_parallel = mapping.tp_size
|
||||
|
||||
model_params = dict(torch.load(args.timm_ckpt))
|
||||
model_params = {
|
||||
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
|
||||
ckpt_path = args.pytorch_ckpt
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
model_params = load_file(ckpt_path)
|
||||
else:
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
|
||||
|
||||
prefix = "ema_model.transformer." if use_ema else "transformer."
|
||||
if any(k.startswith(prefix) for k in model_params.keys()):
|
||||
model_params = {
|
||||
key[len(prefix) :] if key.startswith(prefix) else key: value
|
||||
for key, value in model_params.items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
|
||||
pytorch_to_trtllm_name = {
|
||||
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
|
||||
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
|
||||
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
|
||||
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
|
||||
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
|
||||
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
|
||||
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
|
||||
}
|
||||
prefix = "ema_model.transformer."
|
||||
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
|
||||
|
||||
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
|
||||
|
||||
def get_trtllm_name(timm_name):
|
||||
for k, v in timm_to_trtllm_name.items():
|
||||
m = re.match(k, timm_name)
|
||||
if m is not None:
|
||||
if "*" in v:
|
||||
v = v.replace("*", m.groups()[0])
|
||||
return v
|
||||
return timm_name
|
||||
def get_trtllm_name(pytorch_name):
|
||||
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
|
||||
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
|
||||
if trtllm_name_if_matched != pytorch_name:
|
||||
return trtllm_name_if_matched
|
||||
return pytorch_name
|
||||
|
||||
weights = dict()
|
||||
for name, param in model_params.items():
|
||||
@@ -230,7 +160,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
|
||||
|
||||
assert len(weights) == len(model_params)
|
||||
|
||||
# new_prefix = 'f5_transformer.'
|
||||
# new_prefix = "f5_transformer."
|
||||
new_prefix = ""
|
||||
weights = {new_prefix + key: value for key, value in weights.items()}
|
||||
import math
|
||||
@@ -272,19 +202,19 @@ def save_config(args):
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
config = {
|
||||
"architecture": "F5TTS",
|
||||
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
|
||||
"dtype": args.dtype,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 22,
|
||||
"num_attention_heads": 16,
|
||||
"dim_head": 64,
|
||||
"dropout": 0.1,
|
||||
"ff_mult": 2,
|
||||
"hidden_size": args.hidden_size,
|
||||
"num_hidden_layers": args.depth,
|
||||
"num_attention_heads": args.num_heads,
|
||||
"dim_head": args.dim_head,
|
||||
"dropout": 0.0, # inference-only
|
||||
"ff_mult": args.ff_mult,
|
||||
"mel_dim": 100,
|
||||
"text_num_embeds": 256,
|
||||
"text_dim": 512,
|
||||
"conv_layers": 4,
|
||||
"long_skip_connection": False,
|
||||
"text_dim": args.text_dim,
|
||||
"text_mask_padding": args.text_mask_padding,
|
||||
"conv_layers": args.conv_layers,
|
||||
"pe_attn_head": args.pe_attn_head,
|
||||
"mapping": {
|
||||
"world_size": args.cp_size * args.tp_size * args.pp_size,
|
||||
"cp_size": args.cp_size,
|
||||
@@ -296,7 +226,7 @@ def save_config(args):
|
||||
config["quantization"] = {
|
||||
"quant_algo": "FP8",
|
||||
# TODO: add support for exclude modules.
|
||||
# 'exclude_modules': "*final_layer*",
|
||||
# "exclude_modules": "*final_layer*",
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
|
||||
@@ -315,7 +245,7 @@ def covert_and_save(args, rank):
|
||||
pp_size=args.pp_size,
|
||||
)
|
||||
|
||||
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
|
||||
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
|
||||
|
||||
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
|
||||
|
||||
@@ -344,9 +274,9 @@ def main():
|
||||
assert args.pp_size == 1, "PP is not supported yet."
|
||||
|
||||
tik = time.time()
|
||||
if args.timm_ckpt is None:
|
||||
if args.pytorch_ckpt is None:
|
||||
return
|
||||
print("start execute")
|
||||
print("Start execute")
|
||||
execute(args.workers, [covert_and_save] * world_size, args)
|
||||
|
||||
tok = time.time()
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Manual installation of TensorRT, in case not using NVIDIA NGC:
|
||||
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
|
||||
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
|
||||
|
||||
ONNX_PATH=$1
|
||||
@@ -28,7 +30,7 @@ MAX_BATCH_SIZE=8
|
||||
|
||||
MIN_INPUT_LENGTH=1
|
||||
OPT_INPUT_LENGTH=1000
|
||||
MAX_INPUT_LENGTH=3000
|
||||
MAX_INPUT_LENGTH=3000 # 4096
|
||||
|
||||
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
|
||||
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
|
||||
@@ -40,4 +42,3 @@ ${TRTEXEC} \
|
||||
--maxShapes="mel:${MEL_MAX_SHAPE}" \
|
||||
--onnx=${ONNX_PATH} \
|
||||
--saveEngine=${ENGINE_PATH}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user