mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-05 20:40:12 -08:00)

runtime trtllm: clean-up v0 code, several fixes.
@@ -20,7 +20,6 @@ from f5_tts.model.modules import (
|
|||||||
ConvPositionEmbedding,
|
ConvPositionEmbedding,
|
||||||
DiTBlock,
|
DiTBlock,
|
||||||
TimestepEmbedding,
|
TimestepEmbedding,
|
||||||
get_pos_embed_indices,
|
|
||||||
precompute_freqs_cis,
|
precompute_freqs_cis,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,8 +88,7 @@ class TextEmbedding(nn.Module):
     def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):  # noqa: F722
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
-        batch, text_len = text.shape[0], text.shape[1]
-        text = F.pad(text, (0, seq_len - text_len), value=0)  # (opt.) if not self.average_upsampling:
+        text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
 
         if self.mask_padding:
             text_mask = text == 0
 
@@ -102,10 +100,7 @@ class TextEmbedding(nn.Module):
         # possible extra modeling
         if self.extra_modeling:
             # sinus pos emb
-            batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
-            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
-            text_pos_embed = self.freqs_cis[pos_idx]
-            text = text + text_pos_embed
+            text = text + self.freqs_cis[:seq_len, :]
 
             # convnextv2 blocks
             if self.mask_padding:
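Note on the two hunks above: `get_pos_embed_indices` was only ever called here with `batch_start` all zeros, so the gathered rows of `freqs_cis` are exactly its first `seq_len` rows, and the helper (plus the import removed in the first hunk) can go. A minimal sketch of the equivalence, with made-up shapes and the zero-offset gather inlined:

```python
import torch

freqs_cis = torch.randn(4096, 512)  # (precompute_max_pos, text_dim), as registered in TextEmbedding
batch, seq_len = 2, 300

# old path: per-item gather with zero start offsets
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = batch_start.unsqueeze(1) + torch.arange(seq_len).unsqueeze(0)  # zero-offset case of the helper
old = freqs_cis[pos_idx]  # (batch, seq_len, text_dim)

# new path: a plain slice, broadcast over the batch when added to text
new = freqs_cis[:seq_len, :]  # (seq_len, text_dim)
assert torch.equal(old[0], new) and torch.equal(old[1], new)
```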
@@ -242,6 +237,7 @@ class DiT(nn.Module):
         audio_mask: bool["b n"] | None = None,  # noqa: F722
     ):
         seq_len = x.shape[1]
+        # TODO. modify to get text_embed one by one (to avoid misalignment when batching), as done in runtime imple.
         if cache:
             if drop_text:
                 if self.text_uncond is None:
@@ -252,10 +252,9 @@ class CFM(nn.Module):
         assert text.shape[0] == batch
 
         # lens and mask
-        if not exists(lens):
+        if not exists(lens):  # if lens not acquired by trainer from collate_fn
             lens = torch.full((batch,), seq_len, device=device)
-
-        mask = lens_to_mask(lens, length=seq_len)  # useless here, as collate_fn will pad to max length in batch
+        mask = lens_to_mask(lens, length=seq_len)
 
         # get a random span to mask out for training conditionally
         frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
src/f5_tts/runtime/triton_trtllm/.gitignore (vendored, new file, +3)
@@ -0,0 +1,3 @@
+# runtime/triton_trtllm related
+model.cache
+model_repo/
@@ -3,8 +3,7 @@
 ### Quick Start
 Directly launch the service using docker compose.
 ```sh
-# TODO: support F5TTS_v1_Base
-MODEL=F5TTS_Base docker compose up
+MODEL=F5TTS_v1_Base docker compose up
 ```
 
 ### Build Image
@@ -20,10 +19,12 @@ docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm
 ```
 
 ### Export Models to TensorRT-LLM and Launch Server
-Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
+Inside the docker container, we follow the official TensorRT-LLM guide to build the engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
 ```sh
-bash run.sh 0 4 F5TTS_Base
+bash run.sh 0 4 F5TTS_v1_Base
 ```
+> [!NOTE]
+> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`. Remember to use the matched model version (`F5TTS_v1_Base` for v1, `F5TTS_Base` for v0).
 
 ### HTTP Client
 ```sh
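Note: the `ckpt_file` / `vocab_file` variables the note refers to are assigned automatically in `run.sh` (shown further below). A sketch of overriding them for a custom checkpoint, with placeholder paths:

```sh
# In run.sh, replace the auto-detected values with your own files (hypothetical paths):
ckpt_file=/path/to/your/model_custom.pt        # a .safetensors checkpoint also works
vocab_file=/path/to/your/vocab.txt
```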
@@ -49,11 +50,11 @@ benchmark.py --output-dir $log_dir \
 --batch-size $batch_size \
 --enable-warmup \
 --split-name $split_name \
---model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
---vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+--model-path $CKPT_DIR/$model/model_1200000.pt \
+--vocab-file $CKPT_DIR/$model/vocab.txt \
 --vocoder-trt-engine-path $vocoder_trt_engine_path \
 --backend-type $backend_type \
---tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
 ```
 
 ### Benchmark Results
@@ -66,4 +67,5 @@ Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs
 | F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
 
 ### Credits
-1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
+1. [Yuekai Zhang](https://github.com/yuekaizhang)
+2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
@@ -1,5 +1,5 @@
 # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
 #               2025                (authors: Yuekai Zhang)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,39 +19,45 @@ benchmark.py --output-dir $log_dir \
 --batch-size $batch_size \
 --enable-warmup \
 --split-name $split_name \
---model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
---vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+--model-path $CKPT_DIR/$model/model_1200000.pt \
+--vocab-file $CKPT_DIR/$model/vocab.txt \
 --vocoder-trt-engine-path $vocoder_trt_engine_path \
 --backend-type $backend_type \
---tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
 """
 
 import argparse
+import importlib
 import json
 import os
+import sys
 import time
-from typing import Dict, List, Union
 
 import datasets
-import jieba
 import tensorrt as trt
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 import torchaudio
 from datasets import load_dataset
-from f5_tts_trtllm import F5TTS
 from huggingface_hub import hf_hub_download
-from pypinyin import Style, lazy_pinyin
 from tensorrt_llm._utils import trt_dtype_to_torch
 from tensorrt_llm.logger import logger
 from tensorrt_llm.runtime.session import Session, TensorInfo
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
 from vocos import Vocos
 
 
+sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
+
+from f5_tts.eval.utils_eval import padded_mel_batch
+from f5_tts.model.modules import get_vocos_mel_spectrogram
+from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
+
+F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
 
 torch.manual_seed(0)
 
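Note: the helper functions now come from the `f5_tts` package (hence the `sys.path` insertion), and `F5TTS` is loaded through `importlib` because the Triton model version directory is literally named `1`, which is not a valid Python identifier:

```python
import importlib

# "import model_repo_f5_tts.f5_tts.1.f5_tts_trtllm" would be a SyntaxError,
# since the Triton version directory "1" cannot appear in an import statement;
# importlib resolves the dotted path given as a plain string instead.
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
```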
@@ -111,22 +117,20 @@ def get_args():
     return args
 
 
-def padded_mel_batch(ref_mels, max_seq_len):
-    padded_ref_mels = []
-    for mel in ref_mels:
-        # pad along the last dimension
-        padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
-        padded_ref_mels.append(padded_ref_mel)
-    padded_ref_mels = torch.stack(padded_ref_mels)
-    return padded_ref_mels
-
-
 def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
     if use_perf:
         torch.cuda.nvtx.range_push("data_collator")
     target_sample_rate = 24000
     target_rms = 0.1
-    ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
+    (
+        ids,
+        ref_rms_list,
+        ref_mel_list,
+        ref_mel_len_list,
+        estimated_reference_target_mel_len,
+        reference_target_texts_list,
+    ) = (
+        [],
         [],
         [],
         [],
@@ -148,6 +152,7 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
         )
         ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
         ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
+        ref_rms_list.append(ref_rms)
         if ref_rms < target_rms:
             ref_audio_org = ref_audio_org * target_rms / ref_rms
 
@@ -159,40 +164,31 @@ def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
 
         if use_perf:
             torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
-        ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
+        ref_audio = ref_audio.to("cuda")
+        ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
         if use_perf:
             torch.cuda.nvtx.range_pop()
-        ref_mel = ref_mel.squeeze()
-        ref_mel_len = ref_mel.shape[0]
-        assert ref_mel.shape[1] == 100
+        ref_mel_len = ref_mel.shape[-1]
+        assert ref_mel.shape[0] == 100
 
         ref_mel_list.append(ref_mel)
         ref_mel_len_list.append(ref_mel_len)
 
         estimated_reference_target_mel_len.append(
-            int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
+            int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
         )
 
-    max_seq_len = max(estimated_reference_target_mel_len)
-    ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
+    ref_mel_batch = padded_mel_batch(ref_mel_list)
     ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
 
     pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
     text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
 
-    for i, item in enumerate(text_pad_sequence):
-        text_pad_sequence[i] = F.pad(
-            item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
-        )
-        text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
-    text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
-    text_pad_sequence = F.pad(
-        text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
-    )
     if use_perf:
         torch.cuda.nvtx.range_pop()
     return {
         "ids": ids,
+        "ref_rms_list": ref_rms_list,
         "ref_mel_batch": ref_mel_batch,
         "ref_mel_len_batch": ref_mel_len_batch,
         "text_pad_sequence": text_pad_sequence,
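Note: the deleted padding block was a workaround for F5-TTS's token convention — batches are padded with -1 and the model shifts every id by +1 so that 0 becomes the reserved filler token (see `text = text + 1` in `TextEmbedding.forward` above). That shift now happens inside the runtime embedding, so the collator no longer pre-shifts. A toy sketch of the convention:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([5, 9, 2]), torch.tensor([7])]  # made-up token ids
batch = pad_sequence(seqs, padding_value=-1, batch_first=True)
# tensor([[ 5,  9,  2],
#         [ 7, -1, -1]])
shifted = batch + 1  # applied inside the model: pad -1 -> 0, the filler token
# tensor([[ 6, 10,  3],
#         [ 8,  0,  0]])
```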
@@ -216,72 +212,6 @@ def init_distributed():
     return world_size, local_rank, rank
 
 
-def get_tokenizer(vocab_file_path: str):
-    """
-    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
-                - "char" for char-wise tokenizer, need .txt vocab_file
-                - "byte" for utf-8 tokenizer
-                - "custom" if you're directly passing in a path to the vocab.txt you want to use
-    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
-                - if use "char", derived from unfiltered character & symbol counts of custom dataset
-                - if use "byte", set to 256 (unicode byte range)
-    """
-    with open(vocab_file_path, "r", encoding="utf-8") as f:
-        vocab_char_map = {}
-        for i, char in enumerate(f):
-            vocab_char_map[char[:-1]] = i
-    vocab_size = len(vocab_char_map)
-    return vocab_char_map, vocab_size
-
-
-def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
-    final_reference_target_texts_list = []
-    custom_trans = str.maketrans(
-        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
-    )  # add custom trans here, to address oov
-
-    def is_chinese(c):
-        return "\u3100" <= c <= "\u9fff"  # common chinese characters
-
-    for text in reference_target_texts_list:
-        char_list = []
-        text = text.translate(custom_trans)
-        for seg in jieba.cut(text):
-            seg_byte_len = len(bytes(seg, "UTF-8"))
-            if seg_byte_len == len(seg):  # if pure alphabets and symbols
-                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
-                    char_list.append(" ")
-                char_list.extend(seg)
-            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
-                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
-                for i, c in enumerate(seg):
-                    if is_chinese(c):
-                        char_list.append(" ")
-                    char_list.append(seg_[i])
-            else:  # if mixed characters, alphabets and symbols
-                for c in seg:
-                    if ord(c) < 256:
-                        char_list.extend(c)
-                    elif is_chinese(c):
-                        char_list.append(" ")
-                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
-                    else:
-                        char_list.append(c)
-        final_reference_target_texts_list.append(char_list)
-
-    return final_reference_target_texts_list
-
-
-def list_str_to_idx(
-    text: Union[List[str], List[List[str]]],
-    vocab_char_map: Dict[str, int],  # {char: idx}
-    padding_value=-1,
-):
-    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
-    # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
-    return list_idx_tensors
-
-
 def load_vocoder(
     vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
 ):
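Note: these local copies are dropped in favor of the canonical implementations imported from `f5_tts.model.utils` at the top of the file. For reference, a hedged example of what the shared `convert_char_to_pinyin` produces on mixed input (the output shown is inferred from the logic above, not a captured run):

```python
from f5_tts.model.utils import convert_char_to_pinyin

texts = ["你好, world"]
# Chinese characters become space-separated TONE3 pinyin tokens;
# ASCII segments stay as individual characters.
print(convert_char_to_pinyin(texts, polyphone=True))
# e.g. [[' ', 'ni3', ' ', 'hao3', ',', ' ', 'w', 'o', 'r', 'l', 'd']]
```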
@@ -316,29 +246,11 @@ def load_vocoder(
     return vocoder
 
 
-def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
-    if vocoder == "vocos":
-        mel_stft = torchaudio.transforms.MelSpectrogram(
-            sample_rate=24000,
-            n_fft=1024,
-            win_length=1024,
-            hop_length=256,
-            n_mels=100,
-            power=1,
-            center=True,
-            normalized=False,
-            norm=None,
-        ).to(device)
-    mel = mel_stft(waveform.to(device))
-    mel = mel.clamp(min=1e-5).log()
-    return mel.transpose(1, 2)
-
-
 class VocosTensorRT:
     def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
         TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
         trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
-        logger.info(f"Loading vae engine from {engine_path}")
+        logger.info(f"Loading vocoder engine from {engine_path}")
         self.engine_path = engine_path
         with open(engine_path, "rb") as f:
             engine_buffer = f.read()
@@ -368,20 +280,20 @@ def main():
     world_size, local_rank, rank = init_distributed()
     device = torch.device(f"cuda:{local_rank}")
 
-    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
+    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
 
-    tllm_model_dir = args.tllm_model_dir
-    config_file = os.path.join(tllm_model_dir, "config.json")
-    with open(config_file) as f:
-        config = json.load(f)
     if args.backend_type == "trt":
+        tllm_model_dir = args.tllm_model_dir
+        with open(os.path.join(tllm_model_dir, "config.json")) as f:
+            tllm_model_config = json.load(f)
         model = F5TTS(
-            config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
+            tllm_model_config,
+            debug_mode=False,
+            tllm_model_dir=tllm_model_dir,
+            model_path=args.model_path,
+            vocab_size=vocab_size,
         )
     elif args.backend_type == "pytorch":
-        import sys
-
-        sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
         from f5_tts.infer.utils_infer import load_model
         from f5_tts.model import DiT
 
@@ -445,20 +357,23 @@ def main():
             ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
             text_pad_seq = batch["text_pad_sequence"].to(device)
             total_mel_lens = batch["estimated_reference_target_mel_len"]
+            cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
             if args.backend_type == "trt":
                 _ = model.sample(
-                    text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
+                    text_pad_seq,
+                    cond_pad_seq,
+                    ref_mel_lens,
+                    total_mel_lens,
+                    remove_input_padding=args.remove_input_padding,
                 )
             elif args.backend_type == "pytorch":
-                with torch.inference_mode():
-                    text_pad_seq -= 1
-                    text_pad_seq[text_pad_seq == -2] = -1
                 total_mel_lens = torch.tensor(total_mel_lens, device=device)
+                with torch.inference_mode():
                     generated, _ = model.sample(
                         cond=ref_mels,
                         text=text_pad_seq,
                         duration=total_mel_lens,
-                        steps=16,
+                        steps=32,
                         cfg_strength=2.0,
                         sway_sampling_coef=-1,
                     )
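Note: `cond_pad_seq` pre-pads the reference mels along the time axis to the longest estimated total (reference + target) length, so the conditioning tensor already has the shape the engine samples at. A small shape sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

ref_mels = torch.randn(2, 350, 100)  # (batch, max_ref_len, n_mel), toy sizes
total_mel_lens = [600, 520]          # estimated reference + target lengths
# F.pad pairs run from the last dim backwards: (mel lo, hi), (time lo, hi), (batch lo, hi)
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
print(cond_pad_seq.shape)  # torch.Size([2, 600, 100])
```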
@@ -478,13 +393,13 @@ def main():
         ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
         text_pad_seq = batch["text_pad_sequence"].to(device)
         total_mel_lens = batch["estimated_reference_target_mel_len"]
+        cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
         if args.use_perf:
             torch.cuda.nvtx.range_pop()
         if args.backend_type == "trt":
             generated, cost_time = model.sample(
                 text_pad_seq,
-                ref_mels,
+                cond_pad_seq,
                 ref_mel_lens,
                 total_mel_lens,
                 remove_input_padding=args.remove_input_padding,
@@ -494,20 +409,20 @@ def main():
             total_mel_lens = torch.tensor(total_mel_lens, device=device)
             with torch.inference_mode():
                 start_time = time.time()
-                text_pad_seq -= 1
-                text_pad_seq[text_pad_seq == -2] = -1
                 generated, _ = model.sample(
                     cond=ref_mels,
                     text=text_pad_seq,
                     duration=total_mel_lens,
                     lens=ref_mel_lens,
-                    steps=16,
+                    steps=32,
                     cfg_strength=2.0,
                     sway_sampling_coef=-1,
                 )
         cost_time = time.time() - start_time
         decoding_time += cost_time
         vocoder_start_time = time.time()
+        target_rms = 0.1
+        target_sample_rate = 24000
         for i, gen in enumerate(generated):
             gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
             gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
@@ -519,13 +434,10 @@ def main():
                     torch.cuda.nvtx.range_pop()
             else:
                 generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
-            target_rms = 0.1
-            target_sample_rate = 24_000
-            # if ref_rms_list[i] < target_rms:
-            # generated_wave = generated_wave * ref_rms_list[i] / target_rms
-            rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
-            if rms < target_rms:
-                generated_wave = generated_wave * target_rms / rms
+            if batch["ref_rms_list"][i] < target_rms:
+                generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
             utt = batch["ids"][i]
             torchaudio.save(
                 f"{args.output_dir}/{utt}.wav",
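Note: this flips the loudness policy. Before, the generated waveform was normalized up to `target_rms` based on its own level; now, when the reference clip was quieter than `target_rms` (and was therefore boosted before feature extraction), the output is scaled back down by the same ratio, which matches the behavior of the main package's inference utilities. Toy comparison:

```python
import torch

target_rms = 0.1
ref_rms = torch.tensor(0.05)               # quiet reference: input was boosted before encoding
generated_wave = torch.randn(24000) * 0.1  # stand-in for vocoder output

# old policy: renormalize the output to its own target level
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
old = generated_wave * target_rms / rms if rms < target_rms else generated_wave

# new policy: undo the input boost, preserving the reference's original loudness
new = generated_wave * ref_rms / target_rms if ref_rms < target_rms else generated_wave
```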
@@ -30,15 +30,6 @@ python3 client_grpc.py \
     --huggingface-dataset yuekai/seed_tts \
     --split-name test_zh \
     --log-dir ./log_concurrent_tasks_${num_task}
 
-# For offline Spark-TTS-0.5B
-python3 client_grpc.py \
-    --server-addr localhost \
-    --model-name spark_tts \
-    --num-tasks $num_task \
-    --huggingface-dataset yuekai/seed_tts \
-    --split-name wenetspeech4tts \
-    --log-dir ./log_concurrent_tasks_${num_task}
 """
 
 import argparse
@@ -176,8 +167,7 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -206,7 +196,7 @@ def get_args():
         "--log-dir",
         type=str,
         required=False,
-        default="./tmp",
+        default="./tests/client_grpc",
         help="log directory",
     )
 
@@ -24,6 +24,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 import argparse
+import os
 
 import numpy as np
 import requests
@@ -65,14 +66,13 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts"],
         help="triton model_repo module name to request",
     )
 
     parser.add_argument(
         "--output-audio",
         type=str,
-        default="output.wav",
+        default="tests/client_http.wav",
         help="Path to save the output audio",
     )
     return parser.parse_args()
@@ -140,4 +140,5 @@ if __name__ == "__main__":
     result = rsp.json()
     audio = result["outputs"][0]["data"]
     audio = np.array(audio, dtype=np.float32)
+    os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
     sf.write(args.output_audio, audio, 24000, "PCM_16")
@@ -12,6 +12,7 @@ import torch.nn.functional as F
 from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
 from tensorrt_llm.logger import logger
 from tensorrt_llm.runtime.session import Session
+from torch.nn.utils.rnn import pad_sequence
 
 
 def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
@@ -38,20 +39,15 @@ class TextEmbedding(nn.Module):
         self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
         self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
 
-    def forward(self, text):
-        # only keep tensors with value not -1
-        text_mask = text != -1
-        text_pad_cut_off_index = text_mask.sum(dim=1).max()
-
-        text = text[:, :text_pad_cut_off_index]
-        text = self.text_embed(text)
-        text = text + self.freqs_cis[: text.shape[1], :]
-        for block in self.text_blocks:
-            text = block(text)
-        # padding text to the original length
-        # text shape: B,seq_len,C
-        # pad at the second dimension
-        text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
+    def forward(self, text, seq_len):
+        text = text + 1
+        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
+
+        text = self.text_embed(text)  # b n -> b n d
+        text = text + self.freqs_cis[:seq_len, :]
+        text = self.text_blocks(text)
+
         return text
 
 
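Note the interplay with the `sample()` changes below: the runtime forward now mirrors the training-side `TextEmbedding` (same +1 shift, zero-fill padding to `seq_len`, `freqs_cis` slice, ConvNeXt stack), and a fully dropped text sequence arrives as all -1 ids, which `text = text + 1` maps to the reserved filler id 0 for the CFG unconditional branch:

```python
import torch

text_dropped = torch.full((1, 8), -1, dtype=torch.int32)  # the "drop text" input built in sample()
assert torch.equal(text_dropped + 1, torch.zeros((1, 8), dtype=torch.int32))
# every position becomes filler id 0, i.e. the unconditional text embedding
```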
@@ -112,20 +108,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
     return torch.cat([freqs_cos, freqs_sin], dim=-1)
 
 
-def load_checkpoint(ckpt_path, use_ema=True):
-    checkpoint = torch.load(ckpt_path, weights_only=True)
+def get_text_embed_dict(ckpt_path, use_ema=True):
+    ckpt_type = ckpt_path.split(".")[-1]
+    if ckpt_type == "safetensors":
+        from safetensors.torch import load_file
+
+        checkpoint = load_file(ckpt_path)
+    else:
+        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+
     if use_ema:
+        if ckpt_type == "safetensors":
+            checkpoint = {"ema_model_state_dict": checkpoint}
         checkpoint["model_state_dict"] = {
             k.replace("ema_model.", ""): v
             for k, v in checkpoint["ema_model_state_dict"].items()
             if k not in ["initted", "step"]
         }
-        dict_state = checkpoint["model_state_dict"]
+    else:
+        if ckpt_type == "safetensors":
+            checkpoint = {"model_state_dict": checkpoint}
+    model_params = checkpoint["model_state_dict"]
 
     text_embed_dict = {}
-    for key in dict_state.keys():
+    for key in model_params.keys():
         # transformer.text_embed.text_embed.weight -> text_embed.weight
         if "text_embed" in key:
-            text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
+            text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
     return text_embed_dict
 
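Note: the loader now accepts both checkpoint formats — `.pt` trainer checkpoints wrap weights in `ema_model_state_dict` / `model_state_dict`, while `.safetensors` files are a flat tensor dict (wrapped on the fly above so both paths share the key-stripping code). A small illustrative helper, assuming only these two layouts:

```python
import torch

def peek_checkpoint(ckpt_path):
    """Print the first few top-level keys of either checkpoint format."""
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        keys = list(load_file(ckpt_path).keys())  # flat: 'ema_model.transformer. ...'
    else:
        # wrapped: ['ema_model_state_dict', 'model_state_dict', ...]
        keys = list(torch.load(ckpt_path, map_location="cpu", weights_only=True).keys())
    print(keys[:5])
```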
@@ -196,15 +205,14 @@ class F5TTS(object):
         self.text_embedding = TextEmbedding(
             text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
         ).to(self.device)
-        self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
+        self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
 
-        self.target_audio_sample_rate = 24000
-        self.target_rms = 0.15  # target rms for audio
-        self.n_fft = 1024
-        self.win_length = 1024
-        self.hop_length = 256
+        # self.target_audio_sample_rate = 24000
+        # self.target_rms = 0.1  # least rms when inference, normalize to if lower
+        # self.n_fft = 1024
+        # self.win_length = 1024
+        # self.hop_length = 256
         self.n_mel_channels = 100
-        # self.max_mel_len = 3000
         self.head_dim = 64
         self.base_rescale_factor = 1.0
         self.interpolation_factor = 1.0
@@ -214,12 +222,21 @@ class F5TTS(object):
         self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
         self.rope_cos = self.freqs.cos().half()
         self.rope_sin = self.freqs.sin().half()
-        self.nfe_steps = 16
-        t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
-        time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
+        self.nfe_steps = 32
+        epss = {
+            5: [0, 2, 4, 8, 16, 32],
+            6: [0, 2, 4, 6, 8, 16, 32],
+            7: [0, 2, 4, 6, 8, 16, 24, 32],
+            10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
+            12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
+            16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
+        }
+        t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
+        time_step = 1 - torch.cos(torch.pi * t / 2)
         delta_t = torch.diff(time_step)
-        # WAR: hard coding 256 here
-        tmp_dim = 256
+        tmp_dim = 256  # WAR: hard coding 256 here
         time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
         half_dim = tmp_dim // 2
         emb_factor = math.log(10000) / (half_dim - 1)
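Note: the schedule itself is unchanged — with sway coefficient s = -1, the old expression t + s·(cos(πt/2) − 1 + t) simplifies to 1 − cos(πt/2). What changed is the grid: `epss` maps reduced step counts to hand-picked subsets of the 32-step grid (denser near t = 0, where the flow is most sensitive), and for `nfe_steps = 32` the `.get` default falls back to the full uniform grid. Quick numerical check of the algebra:

```python
import torch

t = torch.linspace(0, 1, 33)
s = -1.0
old = t + s * (torch.cos(torch.pi * 0.5 * t) - 1 + t)  # sway-sampling form
new = 1 - torch.cos(torch.pi * t / 2)                  # simplified form used now
assert torch.allclose(old, new, atol=1e-6)
```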
@@ -344,7 +361,7 @@ class F5TTS(object):
     def sample(
         self,
         text_pad_sequence: torch.Tensor,
-        ref_mel_batch: torch.Tensor,
+        cond_pad_sequence: torch.Tensor,
         ref_mel_len_batch: torch.Tensor,
         estimated_reference_target_mel_len: List[int],
         remove_input_padding: bool = False,
@@ -353,26 +370,43 @@ class F5TTS(object):
         if use_perf:
             torch.cuda.nvtx.range_push("text embedding")
         batch = text_pad_sequence.shape[0]
-        max_seq_len = ref_mel_batch.shape[1]
+        max_seq_len = cond_pad_sequence.shape[1]
 
-        text_pad_sequence_drop = torch.cat(
-            (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
-        )
+        # get text_embed one by one to avoid misalignment
+        text_and_drop_embedding_list = []
+        for i in range(batch):
+            text_and_drop_embedding_i = self.text_embedding(
+                torch.cat(
+                    (
+                        text_pad_sequence[i].unsqueeze(0).to(self.device),
+                        torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
+                    ),
+                    dim=0,
+                ),
+                estimated_reference_target_mel_len[i],
+            )
+            text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
 
-        text_embedding_drop_list = []
-        for i in range(batch + 1):
-            text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
-        text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
+        # pad separately computed text_embed to form batch with max_seq_len
+        text_and_drop_embedding = pad_sequence(
+            text_and_drop_embedding_list,
+            batch_first=True,
+            padding_value=0,
+        )
+        text_embedding = text_and_drop_embedding[0::2]
+        text_embedding_drop = text_and_drop_embedding[1::2]
 
-        text_embedding = text_embedding_drop_condition[:-1]
-        # text_embedding_drop B,T,C batch should be the same
-        text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
-
-        noise = torch.randn_like(ref_mel_batch).to(self.device)
+        noise = torch.randn_like(cond_pad_sequence).to(self.device)
         rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
         rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
 
-        cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
+        cat_mel_text = torch.cat(
+            (
+                cond_pad_sequence,
+                text_embedding,
+            ),
+            dim=-1,
+        )
         cat_mel_text_drop = torch.cat(
             (
                 torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
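Note on the per-item loop: each item is embedded together with its own dropped copy (all -1 ids) at that item's true length `estimated_reference_target_mel_len[i]`, so the positional table and ConvNeXt blocks never see another item's padding. The 2·batch embeddings come back interleaved (cond, drop, cond, drop, ...), and the `[0::2]` / `[1::2]` strides de-interleave them after a single `pad_sequence`. A toy illustration of the bookkeeping:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Stand-ins for per-item (cond, drop) embeddings of different lengths.
emb = [torch.ones(5, 4), torch.zeros(5, 4),      # item 0: cond, drop (len 5)
       2 * torch.ones(7, 4), torch.zeros(7, 4)]  # item 1: cond, drop (len 7)
padded = pad_sequence(emb, batch_first=True, padding_value=0)  # (4, 7, 4)
cond = padded[0::2]  # (2, 7, 4): conditional text embeddings
drop = padded[1::2]  # (2, 7, 4): unconditional (CFG) embeddings
```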
@@ -28,7 +28,6 @@ import os
 
 import jieba
 import torch
-import torch.nn.functional as F
 import torchaudio
 import triton_python_backend_utils as pb_utils
 from f5_tts_trtllm import F5TTS
@@ -99,7 +98,8 @@ def list_str_to_idx(
     padding_value=-1,
 ):  # noqa: F722
     list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
-    return list_idx_tensors
+    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
+    return text
 
 
 class TritonPythonModel:
@@ -107,12 +107,12 @@ class TritonPythonModel:
         self.use_perf = True
         self.device = torch.device("cuda")
         self.target_audio_sample_rate = 24000
-        self.target_rms = 0.15  # target rms for audio
+        self.target_rms = 0.1  # least rms when inference, normalize to if lower
         self.n_fft = 1024
         self.win_length = 1024
         self.hop_length = 256
         self.n_mel_channels = 100
-        self.max_mel_len = 3000
+        self.max_mel_len = 4096
        self.head_dim = 64
 
         parameters = json.loads(args["model_config"])["parameters"]
@@ -181,7 +181,8 @@ class TritonPythonModel:
             reference_target_texts_list,
             estimated_reference_target_mel_len,
             reference_mel_len,
-        ) = [], [], [], [], []
+            reference_rms_list,
+        ) = [], [], [], [], [], []
         mel_features_list = []
         if self.use_perf:
             torch.cuda.nvtx.range_push("preprocess")
@@ -208,6 +209,7 @@ class TritonPythonModel:
                 ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
                 if ref_rms < self.target_rms:
                     wav = wav * self.target_rms / ref_rms
+                reference_rms_list.append(ref_rms)
                 if self.reference_sample_rate != self.target_audio_sample_rate:
                     wav = self.resampler(wav)
                 wav = wav.to(self.device)
@@ -237,15 +239,6 @@ class TritonPythonModel:
         pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
         text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
 
-        for i, item in enumerate(text_pad_sequence):
-            text_pad_sequence[i] = F.pad(
-                item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
-            )
-            text_pad_sequence[i] += 1  # WAR: 0 is reserved for padding token, hard coding in F5-TTS
-        text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
-        text_pad_sequence = F.pad(
-            text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
-        )
         if self.use_perf:
             torch.cuda.nvtx.range_pop()
 
@@ -266,9 +259,8 @@ class TritonPythonModel:
             estimated_mel_len = estimated_reference_target_mel_len[i]
             denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
             audio = self.forward_vocoder(denoised_one_item)
-            rms = torch.sqrt(torch.mean(torch.square(audio)))
-            if rms < self.target_rms:
-                audio = audio * self.target_rms / rms
+            if reference_rms_list[i] < self.target_rms:
+                audio = audio * reference_rms_list[i] / self.target_rms
 
             audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
             inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
@@ -80,7 +80,7 @@ class F5TTS(PretrainedModel):
         max_batch_size = kwargs["max_batch_size"]
         batch_size_range = [2, 2, max_batch_size]
         mel_size = 100
-        max_seq_len = 3000
+        max_seq_len = 4096
         num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
         hidden_size = 512
         concat_feature_dim = mel_size + hidden_size
@@ -1,24 +0,0 @@
-accelerate>=0.33.0
-bitsandbytes>0.37.0
-cached_path
-click
-datasets
-ema_pytorch>=0.5.2
-gradio>=3.45.2
-hydra-core>=1.3.0
-jieba
-librosa
-matplotlib
-numpy<=1.26.4
-pydub
-pypinyin
-safetensors
-soundfile
-tomli
-torch>=2.0.0
-# torchaudio>=2.0.0
-torchdiffeq
-tqdm>=4.65.0
-transformers
-x_transformers>=1.31.14
-packaging>=24.2
@@ -1,64 +1,66 @@
 stage=$1
 stop_stage=$2
-model=$3  # F5TTS_Base
+model=$3  # F5TTS_v1_Base | F5TTS_Base
 if [ -z "$model" ]; then
-    echo "Model is none, using default model F5TTS_Base"
-    model=F5TTS_Base
+    model=F5TTS_v1_Base
 fi
 echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
 export CUDA_VISIBLE_DEVICES=0
 
-F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
-F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
-F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
+CKPT_DIR=../../../../ckpts
+TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
+TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
 
-vocoder_trt_engine_path=vocos_vocoder.plan
-model_repo=./model_repo
+VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
+VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
+MODEL_REPO=./model_repo
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-    echo "Downloading f5 tts from huggingface"
-    huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
+    echo "Downloading F5-TTS from huggingface"
+    huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
 fi
 
+ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1)  # default select latest update
+vocab_file=$CKPT_DIR/$model/vocab.txt
+
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     echo "Converting checkpoint"
-    python3 ./scripts/convert_checkpoint.py \
-        --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
-        --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
+    python3 scripts/convert_checkpoint.py \
+        --timm_ckpt $ckpt_file \
+        --output_dir $TRTLLM_CKPT_DIR --model_name $model
     python_package_path=/usr/local/lib/python3.12/dist-packages
     cp -r patch/* $python_package_path/tensorrt_llm/models
-    trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
+    trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
                  --max_batch_size 8 \
-                 --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
+                 --output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     echo "Exporting vocos vocoder"
-    onnx_vocoder_path=vocos_vocoder.onnx
-    python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
-    bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
+    python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
+    bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     echo "Building triton server"
-    rm -r $model_repo
-    cp -r ./model_repo_f5_tts $model_repo
-    python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
-    cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
+    rm -r $MODEL_REPO
+    cp -r ./model_repo_f5_tts $MODEL_REPO
+    python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
+    cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
 fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     echo "Starting triton server"
-    tritonserver --model-repository=$model_repo
+    tritonserver --model-repository=$MODEL_REPO
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     echo "Testing triton server"
     num_task=1
-    log_dir=./log_concurrent_tasks_${num_task}
+    split_name=wenetspeech4tts
+    log_dir=./tests/client_grpc_concurrent_${num_task}_${split_name}
     rm -r $log_dir
-    python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
+    python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
 fi
 
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
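Note: `sort -V` (version sort) compares the numeric fields in the checkpoint filenames, so `tail -1` picks the highest training step regardless of extension; plain lexicographic sort would not. Illustration with hypothetical filenames:

```sh
printf '%s\n' model_90000.pt model_1200000.pt model_1250000.safetensors | sort -V | tail -1
# -> model_1250000.safetensors   (plain `sort` would pick model_90000.pt, since "9" > "1")
```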
@@ -74,37 +76,37 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
     batch_size=1
     split_name=wenetspeech4tts
     backend_type=trt
-    log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+    log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
     rm -r $log_dir
-    ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
     torchrun --nproc_per_node=1 \
     benchmark.py --output-dir $log_dir \
     --batch-size $batch_size \
     --enable-warmup \
     --split-name $split_name \
-    --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
-    --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
-    --vocoder-trt-engine-path $vocoder_trt_engine_path \
+    --model-path $ckpt_file \
+    --vocab-file $vocab_file \
+    --vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
     --backend-type $backend_type \
-    --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+    --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
 fi
 
 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     echo "Native Pytorch: offline decoding benchmark test"
-    pip install -r requirements-pytorch.txt
-    batch_size=1
+    if ! python3 -c "import f5_tts" &> /dev/null; then
+        pip install -e ../../../../
+    fi
+    batch_size=1  # set attn_mask_enabled=True if batched
     split_name=wenetspeech4tts
     backend_type=pytorch
-    log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
+    log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
     rm -r $log_dir
-    ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
     torchrun --nproc_per_node=1 \
     benchmark.py --output-dir $log_dir \
     --batch-size $batch_size \
     --split-name $split_name \
     --enable-warmup \
-    --model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
-    --vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
+    --model-path $ckpt_file \
+    --vocab-file $vocab_file \
     --backend-type $backend_type \
-    --tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
+    --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
 fi
@@ -172,11 +172,12 @@ def parse_arguments():
     parser.add_argument(
         "--model_name",
         type=str,
-        default="F5TTS_Base",
+        default="F5TTS_v1_Base",
         choices=[
+            "F5TTS_v1_Base",
             "F5TTS_Base",
         ],
-    )  # TODO: support F5TTS_v1_Base
+    )
     parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
     parser.add_argument(
         "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
@@ -184,7 +185,6 @@ def parse_arguments():
     parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
     parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
     parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
-    parser.add_argument("--cfg_scale", type=float, default=4.0)
     parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
     parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
    parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -197,18 +197,29 @@ def parse_arguments():
     return args
 
 
-def convert_timm_dit(args, mapping, dtype="float32"):
+def convert_timm_dit(args, mapping, dtype="float32", use_ema=True):
     weights = {}
     tik = time.time()
     torch_dtype = str_dtype_to_torch(dtype)
     tensor_parallel = mapping.tp_size
 
-    model_params = dict(torch.load(args.timm_ckpt))
-    model_params = {
-        k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
-    }
-    prefix = "ema_model.transformer."
-    model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
+    ckpt_path = args.timm_ckpt
+    ckpt_type = ckpt_path.split(".")[-1]
+    if ckpt_type == "safetensors":
+        from safetensors.torch import load_file
+
+        model_params = load_file(ckpt_path)
+    else:
+        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
+
+    prefix = "ema_model.transformer." if use_ema else "transformer."
+    if any(k.startswith(prefix) for k in model_params.keys()):
+        model_params = {
+            key[len(prefix) :] if key.startswith(prefix) else key: value
+            for key, value in model_params.items()
+            if key.startswith(prefix)
+        }
 
     timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
 
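Note: the remapping now handles three layouts — `.pt` EMA checkpoints (keys under `ema_model.transformer.`), `.pt` raw checkpoints (`transformer.`), and flat `.safetensors` dumps; the `any(...)` guard skips stripping when keys are already unprefixed. A toy run of the stripping logic:

```python
params = {
    "ema_model.transformer.proj.weight": 1,  # kept, prefix stripped
    "ema_model.mel_spec.window": 2,          # dropped: not a transformer weight
}
prefix = "ema_model.transformer."
if any(k.startswith(prefix) for k in params):
    params = {k[len(prefix):]: v for k, v in params.items() if k.startswith(prefix)}
print(params)  # {'proj.weight': 1}
```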
@@ -230,7 +241,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
 
     assert len(weights) == len(model_params)
 
-    # new_prefix = 'f5_transformer.'
+    # new_prefix = "f5_transformer."
     new_prefix = ""
     weights = {new_prefix + key: value for key, value in weights.items()}
     import math
@@ -278,7 +289,7 @@ def save_config(args):
         "num_hidden_layers": 22,
         "num_attention_heads": 16,
         "dim_head": 64,
-        "dropout": 0.1,
+        "dropout": 0.0,  # 0.1
         "ff_mult": 2,
         "mel_dim": 100,
         "text_num_embeds": 256,
@@ -296,7 +307,7 @@ def save_config(args):
     config["quantization"] = {
         "quant_algo": "FP8",
         # TODO: add support for exclude modules.
-        # 'exclude_modules': "*final_layer*",
+        # "exclude_modules": "*final_layer*",
     }
 
     with open(os.path.join(args.output_dir, "config.json"), "w") as f:
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Manual installation of TensorRT, in case not using NVIDIA NGC:
+# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
 TRTEXEC="/usr/src/tensorrt/bin/trtexec"
 
 ONNX_PATH=$1
@@ -40,4 +42,3 @@ ${TRTEXEC} \
     --maxShapes="mel:${MEL_MAX_SHAPE}" \
     --onnx=${ONNX_PATH} \
     --saveEngine=${ENGINE_PATH}
-