From d9dfbe47cc076d3cb9f825e383cd2e573d2b7a6b Mon Sep 17 00:00:00 2001
From: SWivid
Date: Thu, 3 Apr 2025 14:36:22 +0800
Subject: [PATCH] Update README.md

---
 README.md                                   |  18 +-
 src/f5_tts/model/cfm.py                     |   2 +-
 src/f5_tts/runtime/triton_trtllm/README.md  |  11 +-
 .../runtime/triton_trtllm/patch/__init__.py | 273 +++++++++---------
 4 files changed, 160 insertions(+), 144 deletions(-)

diff --git a/README.md b/README.md
index ea0793f..86d8ccf 100644
--- a/README.md
+++ b/README.md
@@ -110,6 +110,9 @@ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,targ
 
 ## Inference
 
+- To achieve the desired performance, take a moment to read the [detailed guidance](src/f5_tts/infer).
+- The [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very helpful; search the keywords of the problem you encounter before opening a new one.
+
 ### 1. Gradio App
 
 Currently supported features:
@@ -176,10 +179,18 @@ f5-tts_infer-cli -c custom.toml
 f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
 ```
 
-### 3. More instructions
+### 3. Runtime
 
-- In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
-- The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
+A deployment solution with Triton and TensorRT-LLM.
+
+#### Benchmark Results
+Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
+
+| Model | Concurrency | Avg Latency | RTF |
+|-------|-------------|----------------|-------|
+| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
+
+See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
 
 ## Training
 
@@ -231,6 +242,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
 - [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
+- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
 
 ## Citation
 If our work and codebase is useful for you, please cite as:
diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py
index ea4b67f..90679be 100644
--- a/src/f5_tts/model/cfm.py
+++ b/src/f5_tts/model/cfm.py
@@ -270,7 +270,7 @@ class CFM(nn.Module):
         else:
             drop_text = False
 
-        # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
+        # if you want to rigorously mask out padding, record it in collate_fn in dataset.py, and pass it in here
         # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
         pred = self.transformer(
             x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md
index a25c9be..a202c18 100644
--- a/src/f5_tts/runtime/triton_trtllm/README.md
+++ b/src/f5_tts/runtime/triton_trtllm/README.md
@@ -1,4 +1,4 @@
-## Triton Inference Serving Best Practice for F5 TTS
+## Triton Inference Serving Best Practice for F5-TTS
 
 ### Quick Start
 Directly launch the service using docker compose.
@@ -21,14 +21,15 @@ docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm
 
 ### Export Models to TensorRT-LLM and Launch Server
 Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
-
 ```sh
 bash run.sh 0 4 F5TTS_Base
 ```
+
 ### HTTP Client
 ```sh
 python3 client_http.py
 ```
+
 ### Benchmark using Dataset
 ```sh
 num_task=2
@@ -38,9 +39,9 @@ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_t
 ### Benchmark Results
 Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
 
-| Model | Concurrency | Avg Latency | RTF |
-|-------|-------------|-----------------|--|
+| Model | Concurrency | Avg Latency | RTF |
+|-------|-------------|----------------|-------|
 | F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
 
 ### Credits
-1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
\ No newline at end of file
+1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
diff --git a/src/f5_tts/runtime/triton_trtllm/patch/__init__.py b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py
index d43cacc..445f9cc 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/__init__.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py
@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .baichuan.model import BaichuanForCausalLM
-from .bert.model import (BertForQuestionAnswering,
-                         BertForSequenceClassification, BertModel,
-                         RobertaForQuestionAnswering,
-                         RobertaForSequenceClassification, RobertaModel)
+from .bert.model import (
+    BertForQuestionAnswering,
+    BertForSequenceClassification,
+    BertModel,
+    RobertaForQuestionAnswering,
+    RobertaForSequenceClassification,
+    RobertaModel,
+)
 from .bloom.model import BloomForCausalLM, BloomModel
 from .chatglm.config import ChatGLMConfig
 from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
@@ -46,8 +50,7 @@ from .mamba.model import MambaForCausalLM
 from .medusa.config import MedusaConfig
 from .medusa.model import MedusaForCausalLm
 from .mllama.model import MLLaMAModel
-from .modeling_utils import (PretrainedConfig, PretrainedModel,
-                             SpeculativeDecodingMode)
+from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
 from .mpt.model import MPTForCausalLM, MPTModel
 from .nemotron_nas.model import DeciLMForCausalLM
 from .opt.model import OPTForCausalLM, OPTModel
@@ -59,138 +62,138 @@ from .redrafter.model import ReDrafterForCausalLM
 from .f5tts.model import F5TTS
 
 __all__ = [
-    'BertModel',
-    'BertForQuestionAnswering',
-    'BertForSequenceClassification',
-    'RobertaModel',
-    'RobertaForQuestionAnswering',
-    'RobertaForSequenceClassification',
-    'BloomModel',
-    'BloomForCausalLM',
-    'DiT',
-    'DeepseekForCausalLM',
-    'FalconConfig',
-    'DeepseekV2ForCausalLM',
-    'FalconForCausalLM',
-    'FalconModel',
-    'GPTConfig',
-    'GPTModel',
-    'GPTForCausalLM',
-    'OPTForCausalLM',
-    'OPTModel',
-    'LLaMAConfig',
-    'LLaMAForCausalLM',
-    'LLaMAModel',
-    'MedusaConfig',
-    'MedusaForCausalLm',
-    'ReDrafterForCausalLM',
-    'GPTJConfig',
-    'GPTJModel',
-    'GPTJForCausalLM',
-    'GPTNeoXModel',
-    'GPTNeoXForCausalLM',
-    'PhiModel',
-    'PhiConfig',
-    'Phi3Model',
-    'Phi3Config',
-    'PhiForCausalLM',
-    'Phi3ForCausalLM',
-    'ChatGLMConfig',
-    'ChatGLMForCausalLM',
-    'ChatGLMModel',
-    'BaichuanForCausalLM',
-    'QWenConfig'
-    'QWenForCausalLM',
-    'QWenModel',
-    'EncoderModel',
-    'DecoderModel',
-    'PretrainedConfig',
-    'PretrainedModel',
-    'WhisperEncoder',
-    'MambaForCausalLM',
-    'MambaConfig',
-    'MPTForCausalLM',
-    'MPTModel',
-    'SkyworkForCausalLM',
-    'GemmaConfig',
-    'GemmaForCausalLM',
-    'DbrxConfig',
-    'DbrxForCausalLM',
-    'RecurrentGemmaForCausalLM',
-    'CogVLMConfig',
-    'CogVLMForCausalLM',
-    'EagleForCausalLM',
-    'SpeculativeDecodingMode',
-    'CohereForCausalLM',
-    'MLLaMAModel',
-    'F5TTS',
+    "BertModel",
+    "BertForQuestionAnswering",
+    "BertForSequenceClassification",
+    "RobertaModel",
+    "RobertaForQuestionAnswering",
+    "RobertaForSequenceClassification",
+    "BloomModel",
+    "BloomForCausalLM",
+    "DiT",
+    "DeepseekForCausalLM",
+    "FalconConfig",
+    "DeepseekV2ForCausalLM",
+    "FalconForCausalLM",
+    "FalconModel",
+    "GPTConfig",
+    "GPTModel",
+    "GPTForCausalLM",
+    "OPTForCausalLM",
+    "OPTModel",
+    "LLaMAConfig",
+    "LLaMAForCausalLM",
+    "LLaMAModel",
+    "MedusaConfig",
+    "MedusaForCausalLm",
+    "ReDrafterForCausalLM",
+    "GPTJConfig",
+    "GPTJModel",
+    "GPTJForCausalLM",
+    "GPTNeoXModel",
+    "GPTNeoXForCausalLM",
+    "PhiModel",
+    "PhiConfig",
+    "Phi3Model",
+    "Phi3Config",
+    "PhiForCausalLM",
+    "Phi3ForCausalLM",
+    "ChatGLMConfig",
+    "ChatGLMForCausalLM",
+    "ChatGLMModel",
+    "BaichuanForCausalLM",
+    "QWenConfig",
+    "QWenForCausalLM",
+    "QWenModel",
+    "EncoderModel",
+    "DecoderModel",
+    "PretrainedConfig",
+    "PretrainedModel",
+    "WhisperEncoder",
+    "MambaForCausalLM",
+    "MambaConfig",
+    "MPTForCausalLM",
+    "MPTModel",
+    "SkyworkForCausalLM",
+    "GemmaConfig",
+    "GemmaForCausalLM",
+    "DbrxConfig",
+    "DbrxForCausalLM",
+    "RecurrentGemmaForCausalLM",
+    "CogVLMConfig",
+    "CogVLMForCausalLM",
+    "EagleForCausalLM",
+    "SpeculativeDecodingMode",
+    "CohereForCausalLM",
+    "MLLaMAModel",
+    "F5TTS",
 ]
 
 MODEL_MAP = {
-    'GPT2LMHeadModel': GPTForCausalLM,
-    'GPT2LMHeadCustomModel': GPTForCausalLM,
-    'GPTBigCodeForCausalLM': GPTForCausalLM,
-    'Starcoder2ForCausalLM': GPTForCausalLM,
-    'FuyuForCausalLM': GPTForCausalLM,
-    'Kosmos2ForConditionalGeneration': GPTForCausalLM,
-    'JAISLMHeadModel': GPTForCausalLM,
-    'GPTForCausalLM': GPTForCausalLM,
-    'NemotronForCausalLM': GPTForCausalLM,
-    'OPTForCausalLM': OPTForCausalLM,
-    'BloomForCausalLM': BloomForCausalLM,
-    'RWForCausalLM': FalconForCausalLM,
-    'FalconForCausalLM': FalconForCausalLM,
-    'PhiForCausalLM': PhiForCausalLM,
-    'Phi3ForCausalLM': Phi3ForCausalLM,
-    'Phi3VForCausalLM': Phi3ForCausalLM,
-    'Phi3SmallForCausalLM': Phi3ForCausalLM,
-    'PhiMoEForCausalLM': Phi3ForCausalLM,
-    'MambaForCausalLM': MambaForCausalLM,
-    'GPTNeoXForCausalLM': GPTNeoXForCausalLM,
-    'GPTJForCausalLM': GPTJForCausalLM,
-    'MPTForCausalLM': MPTForCausalLM,
-    'GLMModel': ChatGLMForCausalLM,
-    'ChatGLMModel': ChatGLMForCausalLM,
-    'ChatGLMForCausalLM': ChatGLMForCausalLM,
-    'LlamaForCausalLM': LLaMAForCausalLM,
-    'ExaoneForCausalLM': LLaMAForCausalLM,
-    'MistralForCausalLM': LLaMAForCausalLM,
-    'MixtralForCausalLM': LLaMAForCausalLM,
-    'ArcticForCausalLM': LLaMAForCausalLM,
-    'Grok1ModelForCausalLM': GrokForCausalLM,
-    'InternLMForCausalLM': LLaMAForCausalLM,
-    'InternLM2ForCausalLM': LLaMAForCausalLM,
-    'MedusaForCausalLM': MedusaForCausalLm,
-    'ReDrafterForCausalLM': ReDrafterForCausalLM,
-    'BaichuanForCausalLM': BaichuanForCausalLM,
-    'BaiChuanForCausalLM': BaichuanForCausalLM,
-    'SkyworkForCausalLM': LLaMAForCausalLM,
+    "GPT2LMHeadModel": GPTForCausalLM,
+    "GPT2LMHeadCustomModel": GPTForCausalLM,
+    "GPTBigCodeForCausalLM": GPTForCausalLM,
+    "Starcoder2ForCausalLM": GPTForCausalLM,
+    "FuyuForCausalLM": GPTForCausalLM,
"Kosmos2ForConditionalGeneration": GPTForCausalLM, + "JAISLMHeadModel": GPTForCausalLM, + "GPTForCausalLM": GPTForCausalLM, + "NemotronForCausalLM": GPTForCausalLM, + "OPTForCausalLM": OPTForCausalLM, + "BloomForCausalLM": BloomForCausalLM, + "RWForCausalLM": FalconForCausalLM, + "FalconForCausalLM": FalconForCausalLM, + "PhiForCausalLM": PhiForCausalLM, + "Phi3ForCausalLM": Phi3ForCausalLM, + "Phi3VForCausalLM": Phi3ForCausalLM, + "Phi3SmallForCausalLM": Phi3ForCausalLM, + "PhiMoEForCausalLM": Phi3ForCausalLM, + "MambaForCausalLM": MambaForCausalLM, + "GPTNeoXForCausalLM": GPTNeoXForCausalLM, + "GPTJForCausalLM": GPTJForCausalLM, + "MPTForCausalLM": MPTForCausalLM, + "GLMModel": ChatGLMForCausalLM, + "ChatGLMModel": ChatGLMForCausalLM, + "ChatGLMForCausalLM": ChatGLMForCausalLM, + "LlamaForCausalLM": LLaMAForCausalLM, + "ExaoneForCausalLM": LLaMAForCausalLM, + "MistralForCausalLM": LLaMAForCausalLM, + "MixtralForCausalLM": LLaMAForCausalLM, + "ArcticForCausalLM": LLaMAForCausalLM, + "Grok1ModelForCausalLM": GrokForCausalLM, + "InternLMForCausalLM": LLaMAForCausalLM, + "InternLM2ForCausalLM": LLaMAForCausalLM, + "MedusaForCausalLM": MedusaForCausalLm, + "ReDrafterForCausalLM": ReDrafterForCausalLM, + "BaichuanForCausalLM": BaichuanForCausalLM, + "BaiChuanForCausalLM": BaichuanForCausalLM, + "SkyworkForCausalLM": LLaMAForCausalLM, GEMMA_ARCHITECTURE: GemmaForCausalLM, GEMMA2_ARCHITECTURE: GemmaForCausalLM, - 'QWenLMHeadModel': QWenForCausalLM, - 'QWenForCausalLM': QWenForCausalLM, - 'Qwen2ForCausalLM': QWenForCausalLM, - 'Qwen2MoeForCausalLM': QWenForCausalLM, - 'Qwen2ForSequenceClassification': QWenForCausalLM, - 'Qwen2VLForConditionalGeneration': QWenForCausalLM, - 'WhisperEncoder': WhisperEncoder, - 'EncoderModel': EncoderModel, - 'DecoderModel': DecoderModel, - 'DbrxForCausalLM': DbrxForCausalLM, - 'RecurrentGemmaForCausalLM': RecurrentGemmaForCausalLM, - 'CogVLMForCausalLM': CogVLMForCausalLM, - 'DiT': DiT, - 'DeepseekForCausalLM': DeepseekForCausalLM, - 'DeciLMForCausalLM': DeciLMForCausalLM, - 'DeepseekV2ForCausalLM': DeepseekV2ForCausalLM, - 'EagleForCausalLM': EagleForCausalLM, - 'CohereForCausalLM': CohereForCausalLM, - 'MllamaForConditionalGeneration': MLLaMAModel, - 'BertForQuestionAnswering': BertForQuestionAnswering, - 'BertForSequenceClassification': BertForSequenceClassification, - 'BertModel': BertModel, - 'RobertaModel': RobertaModel, - 'RobertaForQuestionAnswering': RobertaForQuestionAnswering, - 'RobertaForSequenceClassification': RobertaForSequenceClassification, - 'F5TTS': F5TTS + "QWenLMHeadModel": QWenForCausalLM, + "QWenForCausalLM": QWenForCausalLM, + "Qwen2ForCausalLM": QWenForCausalLM, + "Qwen2MoeForCausalLM": QWenForCausalLM, + "Qwen2ForSequenceClassification": QWenForCausalLM, + "Qwen2VLForConditionalGeneration": QWenForCausalLM, + "WhisperEncoder": WhisperEncoder, + "EncoderModel": EncoderModel, + "DecoderModel": DecoderModel, + "DbrxForCausalLM": DbrxForCausalLM, + "RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM, + "CogVLMForCausalLM": CogVLMForCausalLM, + "DiT": DiT, + "DeepseekForCausalLM": DeepseekForCausalLM, + "DeciLMForCausalLM": DeciLMForCausalLM, + "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, + "EagleForCausalLM": EagleForCausalLM, + "CohereForCausalLM": CohereForCausalLM, + "MllamaForConditionalGeneration": MLLaMAModel, + "BertForQuestionAnswering": BertForQuestionAnswering, + "BertForSequenceClassification": BertForSequenceClassification, + "BertModel": BertModel, + "RobertaModel": RobertaModel, + "RobertaForQuestionAnswering": 
RobertaForQuestionAnswering, + "RobertaForSequenceClassification": RobertaForSequenceClassification, + "F5TTS": F5TTS, }
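A note on how the patched registry above is consumed: TensorRT-LLM resolves the `architecture` field of a converted checkpoint's `config.json` through `MODEL_MAP` to pick the model class, which is why this patch registers `F5TTS` both there and in `__all__`. It is also why `__all__` must list `"QWenConfig"` and `"QWenForCausalLM"` as two separate entries: a missing comma concatenates them into one name that does not exist, and `from tensorrt_llm.models import *` would then fail. The sketch below illustrates only the lookup pattern; it is a minimal stand-in under assumed conventions, not the library's actual loading code.

```python
import json
from pathlib import Path


class F5TTS:  # stand-in for the class the patch imports from .f5tts.model
    pass


# Stand-in for the registry patched into tensorrt_llm/models/__init__.py.
MODEL_MAP = {"F5TTS": F5TTS}


def resolve_model_cls(checkpoint_dir: str) -> type:
    """Pick the model class for a converted checkpoint via its architecture name."""
    config = json.loads((Path(checkpoint_dir) / "config.json").read_text())
    architecture = config["architecture"]  # assumed config field, e.g. "F5TTS"
    try:
        return MODEL_MAP[architecture]
    except KeyError:
        raise ValueError(f"unsupported architecture: {architecture!r}") from None
```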
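Similarly, for the `client_http.py` step in the runtime README: Triton serves the standard KServe v2 REST endpoint `POST /v2/models/<model_name>/infer`, which is what a minimal HTTP client has to call. In the sketch below only that URL shape is standard; the model name `f5_tts`, the tensor names, shapes, and datatypes are illustrative assumptions, so consult `client_http.py` and the deployed model config for the real ones.

```python
import requests

# Assumed model name and tensor names; only the /v2/models/<name>/infer
# URL shape is fixed by Triton's KServe v2 REST protocol.
URL = "http://localhost:8000/v2/models/f5_tts/infer"

payload = {
    "inputs": [
        # A real client would also send the reference audio (e.g. an FP32
        # waveform tensor); it is omitted here to keep the sketch minimal.
        {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": ["reference transcript"]},
        {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": ["text to synthesize"]},
    ]
}

resp = requests.post(URL, json=payload, timeout=120)
resp.raise_for_status()
print([out["name"] for out in resp.json()["outputs"]])  # returned tensor names
```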