diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md index ca69e5f..777f6a8 100644 --- a/src/f5_tts/runtime/triton_trtllm/README.md +++ b/src/f5_tts/runtime/triton_trtllm/README.md @@ -35,9 +35,9 @@ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_t ### Benchmark Results Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs. -| Model | Note | Concurrency | Avg Latency | RTF | -|-------|-----------|-----------------------|---------|--| -| F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394| +| Model | Concurrency | Avg Latency | RTF | +|-------|-------------|-----------------|--| +| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394| ### Credits 1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm) \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/client_grpc.py b/src/f5_tts/runtime/triton_trtllm/client_grpc.py index 0d3b154..c4e2c43 100644 --- a/src/f5_tts/runtime/triton_trtllm/client_grpc.py +++ b/src/f5_tts/runtime/triton_trtllm/client_grpc.py @@ -245,6 +245,7 @@ async def send( model_name: str, padding_duration: int = None, audio_save_dir: str = "./", + save_sample_rate: int = 16000, ): total_duration = 0.0 latency_data = [] @@ -267,7 +268,9 @@ async def send( samples = np.zeros( ( 1, - padding_duration * sample_rate * ((int(duration) // padding_duration) + 1), + padding_duration + * sample_rate + * ((int(estimated_target_duration + duration) // padding_duration) + 1), ), dtype=np.float32, ) @@ -306,7 +309,7 @@ async def send( end = time.time() - start audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") - sf.write(audio_save_path, audio, 16000, "PCM_16") + sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") latency_data.append((end, estimated_target_duration)) total_duration += estimated_target_duration @@ -413,7 +416,8 @@ async def main(): log_interval=args.log_interval, model_name=args.model_name, audio_save_dir=args.log_dir, - padding_duration=1.0, + padding_duration=1, + save_sample_rate=24000 if args.model_name == "f5_tts" else 16000, ) ) tasks.append(task) diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py index a0ca9d3..9265886 100644 --- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py @@ -158,7 +158,7 @@ class TritonPythonModel: return mel.transpose(1, 2) def forward_vocoder(self, mel): - mel = mel.to(torch.float32).contiguous() + mel = mel.to(torch.float32).contiguous().cpu() input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)) inference_request = pb_utils.InferenceRequest( diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt index 171211e..4663f7c 100644 --- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt @@ -14,9 +14,9 @@ name: "f5_tts" backend: "python" -max_batch_size: 1 +max_batch_size: 4 dynamic_batching { - max_queue_delay_microseconds: 1 + max_queue_delay_microseconds: 1000 } parameters [ { diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py index b89ca5c..26c8bc9 100644 --- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py @@ -30,8 +30,7 @@ class InputEmbedding(Module): self.proj = Linear(mel_dim * 2 + text_dim, out_dim) self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x, cond, drop_audio_cond=False): - # if drop_audio_cond: # cfg for cond audio + def forward(self, x, cond): x = self.proj(concat([x, cond], dim=-1)) return self.conv_pos_embed(x) + x @@ -41,9 +40,8 @@ class F5TTS(PretrainedModel): super().__init__(config) self.dtype = str_dtype_to_trt(config.dtype) - self.time_embed = TimestepEmbedding(config.hidden_size) # √ - text_dim = config.mel_dim if config.text_dim is None else config.text_dim - self.input_embed = InputEmbedding(config.mel_dim, text_dim, config.hidden_size) + self.time_embed = TimestepEmbedding(config.hidden_size) + self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size) self.dim = config.hidden_size self.depth = config.num_hidden_layers diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py index 5bfd5a0..a0051b4 100644 --- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py @@ -93,7 +93,7 @@ class ConvPositionEmbedding(Module): self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) self.mish = Mish() - def forward(self, x, mask): # noqa: F722 + def forward(self, x, mask=None): # noqa: F722 if default_net().plugin_config.remove_input_padding: x = unsqueeze(x, 0) x = permute(x, [0, 2, 1]) diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh index bf5526d..270c4f5 100644 --- a/src/f5_tts/runtime/triton_trtllm/run.sh +++ b/src/f5_tts/runtime/triton_trtllm/run.sh @@ -14,23 +14,22 @@ F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine -num_task=2 -log_dir=./log_concurrent_tasks_${num_task} vocoder_trt_engine_path=vocos_vocoder.plan model_repo=./model_repo if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - echo "Copying f5 tts trtllm files" - python_package_path=/usr/local/lib/python3.12/dist-packages - cp -r patch/* $python_package_path/tensorrt_llm/models + echo "Downloading f5 tts from huggingface" + huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH + fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - echo "Downloading f5 tts from huggingface" - huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH + echo "Converting checkpoint" python3 ./scripts/convert_checkpoint.py \ --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \ --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model + python_package_path=/usr/local/lib/python3.12/dist-packages + cp -r patch/* $python_package_path/tensorrt_llm/models trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \ --max_batch_size 8 \ --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable @@ -58,5 +57,8 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then echo "Testing triton server" + num_task=1 + log_dir=./log_concurrent_tasks_${num_task} + rm -r $log_dir python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir fi \ No newline at end of file