diff --git a/src/f5_tts/runtime/triton_trtllm/Dockerfile.server b/src/f5_tts/runtime/triton_trtllm/Dockerfile.server
new file mode 100644
index 0000000..861e266
--- /dev/null
+++ b/src/f5_tts/runtime/triton_trtllm/Dockerfile.server
@@ -0,0 +1,3 @@
+FROM nvcr.io/nvidia/tritonserver:24.12-py3
+RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
+WORKDIR /workspace
\ No newline at end of file
diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md
new file mode 100644
index 0000000..a25c9be
--- /dev/null
+++ b/src/f5_tts/runtime/triton_trtllm/README.md
@@ -0,0 +1,46 @@
+## Triton Inference Serving Best Practice for F5-TTS
+
+### Quick Start
+Launch the service directly with docker compose.
+```sh
+# TODO: support F5TTS_v1_Base
+MODEL=F5TTS_Base docker compose up
+```
+
+### Build Image
+Build the docker image from scratch.
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
+```
+
+### Create Docker Container
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
+```
+
+### Export Models to TensorRT-LLM and Launch Server
+Inside the docker container, export the F5-TTS model to a TensorRT-LLM engine and launch the Triton server. The engine-building workflow follows the official TensorRT-LLM examples; see [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
+
+```sh
+bash run.sh 0 4 F5TTS_Base
+```
+### HTTP Client
+```sh
+python3 client_http.py
+```
+### Benchmark using Dataset
+```sh
+num_task=2
+python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
+```
+
+### Benchmark Results
+Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
+
+| Model | Concurrency | Avg Latency | RTF |
+|-------|-------------|-----------------|--|
+| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394 |
+
+### Credits
+1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
\ No newline at end of file
diff --git a/src/f5_tts/runtime/triton_trtllm/client_grpc.py b/src/f5_tts/runtime/triton_trtllm/client_grpc.py
new file mode 100644
index 0000000..c4e2c43
--- /dev/null
+++ b/src/f5_tts/runtime/triton_trtllm/client_grpc.py
@@ -0,0 +1,470 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Nvidia (authors: Yuekai Zhang)
+# 2023 Recurrent.ai (authors: Songtao Shi)
+# See LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a dataset from the Hugging Face hub and sends it to the server
+for decoding, in parallel.
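+
+In addition to the dataset mode shown under "Usage" below, a single
+reference-audio / target-text pair can be synthesized directly. A minimal
+sketch (the flags are the ones defined in get_args() below; the paths and
+texts are placeholders):
+
+python3 client_grpc.py \
+    --server-addr localhost \
+    --model-name f5_tts \
+    --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
+    --reference-text "Some call me nature, others call me mother nature." \
+    --target-text "Hello, this is a test sentence." \
+    --log-dir ./log_single_utterance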
+
+Usage:
+num_task=2
+
+# For offline F5-TTS
+python3 client_grpc.py \
+    --server-addr localhost \
+    --model-name f5_tts \
+    --num-tasks $num_task \
+    --huggingface-dataset yuekai/seed_tts \
+    --split-name test_zh \
+    --log-dir ./log_concurrent_tasks_${num_task}
+
+# For offline Spark-TTS-0.5B
+python3 client_grpc.py \
+    --server-addr localhost \
+    --model-name spark_tts \
+    --num-tasks $num_task \
+    --huggingface-dataset yuekai/seed_tts \
+    --split-name wenetspeech4tts \
+    --log-dir ./log_concurrent_tasks_${num_task}
+"""
+
+import argparse
+import asyncio
+import json
+
+import os
+import time
+import types
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+import tritonclient
+import tritonclient.grpc.aio as grpcclient
+from tritonclient.utils import np_to_triton_dtype
+
+
+def write_triton_stats(stats, summary_file):
+    with open(summary_file, "w") as summary_f:
+        model_stats = stats["model_stats"]
+        # write a note explaining that the log comes from triton_client.get_inference_statistics(), for readability
+        summary_f.write(
+            "The log below is parsed from triton_client.get_inference_statistics() and reformatted for readability. \n"
+        )
+        summary_f.write("To learn more about the log, please refer to: \n")
+        summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
+        summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
+        summary_f.write(
+            "To improve throughput, it usually helps to let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
+        )
+        summary_f.write(
+            "However, there is a trade-off between the increased queue time and the increased batch size. \n"
+        )
+        summary_f.write(
+            "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
+        )
+        summary_f.write(
+            "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details.
\n\n" + ) + for model_state in model_stats: + if "last_inference" not in model_state: + continue + summary_f.write(f"model name is {model_state['name']} \n") + model_inference_stats = model_state["inference_stats"] + total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 + total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 + total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 + total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 + summary_f.write( + f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa + ) + model_batch_stats = model_state["batch_stats"] + for batch in model_batch_stats: + batch_size = int(batch["batch_size"]) + compute_input = batch["compute_input"] + compute_output = batch["compute_output"] + compute_infer = batch["compute_infer"] + batch_count = int(compute_infer["count"]) + assert compute_infer["count"] == compute_output["count"] == compute_input["count"] + compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 + compute_input_time_ms = int(compute_input["ns"]) / 1e6 + compute_output_time_ms = int(compute_output["ns"]) / 1e6 + summary_f.write( + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa + ) + summary_f.write( + f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa + ) + summary_f.write( + f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa + ) + + +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=8001, + help="Grpc port of the triton server, default is 8001", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--huggingface-dataset", + type=str, + default="yuekai/seed_tts", + help="dataset name in huggingface dataset hub", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="dataset split name, default is 'test'", + ) + + parser.add_argument( + "--manifest-path", + type=str, + default=None, + help="Path to the manifest dir which includes wav.scp trans.txt files.", + ) + + parser.add_argument( + "--model-name", + type=str, + default="f5_tts", + choices=["f5_tts", "spark_tts"], + help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + ) + + parser.add_argument( + "--num-tasks", + type=int, + default=1, + help="Number of concurrent tasks for sending", + ) + + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="Controls how frequently we print the log.", + ) + + parser.add_argument( + "--compute-wer", + action="store_true", + default=False, + help="""True to compute WER. + """, + ) + + parser.add_argument( + "--log-dir", + type=str, + required=False, + default="./tmp", + help="log directory", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Inference batch_size per request for offline mode.", + ) + + return parser.parse_args() + + +def load_audio(wav_path, target_sample_rate=16000): + assert target_sample_rate == 16000, "hard coding in server" + if isinstance(wav_path, dict): + waveform = wav_path["array"] + sample_rate = wav_path["sampling_rate"] + else: + waveform, sample_rate = sf.read(wav_path) + if sample_rate != target_sample_rate: + from scipy.signal import resample + + num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) + waveform = resample(waveform, num_samples) + return waveform, target_sample_rate + + +async def send( + manifest_item_list: list, + name: str, + triton_client: tritonclient.grpc.aio.InferenceServerClient, + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + padding_duration: int = None, + audio_save_dir: str = "./", + save_sample_rate: int = 16000, +): + total_duration = 0.0 + latency_data = [] + task_id = int(name[5:]) + + print(f"manifest_item_list: {manifest_item_list}") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: {i}/{len(manifest_item_list)}") + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + duration = len(waveform) / sample_rate + lengths = np.array([[len(waveform)]], dtype=np.int32) + + reference_text, target_text = item["reference_text"], item["target_text"] + + estimated_target_duration = duration / len(reference_text) * len(target_text) + + if padding_duration: + # padding to nearset 10 seconds + samples = np.zeros( + ( + 1, + padding_duration + * sample_rate + * ((int(estimated_target_duration + duration) // padding_duration) + 1), + ), + dtype=np.float32, + ) + + samples[0, : len(waveform)] = waveform + else: + samples = waveform + + samples = samples.reshape(1, -1).astype(np.float32) + + inputs = [ + protocol_client.InferInput("reference_wav", samples.shape, 
np_to_triton_dtype(samples.dtype)), + protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)), + protocol_client.InferInput("reference_text", [1, 1], "BYTES"), + protocol_client.InferInput("target_text", [1, 1], "BYTES"), + ] + inputs[0].set_data_from_numpy(samples) + inputs[1].set_data_from_numpy(lengths) + + input_data_numpy = np.array([reference_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[2].set_data_from_numpy(input_data_numpy) + + input_data_numpy = np.array([target_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[3].set_data_from_numpy(input_data_numpy) + + outputs = [protocol_client.InferRequestedOutput("waveform")] + + sequence_id = 100000000 + i + task_id * 10 + start = time.time() + response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) + + audio = response.as_numpy("waveform").reshape(-1) + + end = time.time() - start + + audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") + sf.write(audio_save_path, audio, save_sample_rate, "PCM_16") + + latency_data.append((end, estimated_target_duration)) + total_duration += estimated_target_duration + + return total_duration, latency_data + + +def load_manifests(manifest_path): + with open(manifest_path, "r") as f: + manifest_list = [] + for line in f: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) + manifest_list.append( + { + "audio_filepath": prompt_wav, + "reference_text": prompt_text, + "target_text": gt_text, + "target_audio_path": utt, + } + ) + return manifest_list + + +def split_data(data, k): + n = len(data) + if n < k: + print(f"Warning: the length of the input list ({n}) is less than k ({k}). 
Setting k to {n}.") + k = n + + quotient = n // k + remainder = n % k + + result = [] + start = 0 + for i in range(k): + if i < remainder: + end = start + quotient + 1 + else: + end = start + quotient + + result.append(data[start:end]) + start = end + + return result + + +async def main(): + args = get_args() + url = f"{args.server_addr}:{args.server_port}" + + triton_client = grpcclient.InferenceServerClient(url=url, verbose=False) + protocol_client = grpcclient + + if args.reference_audio: + args.num_tasks = 1 + args.log_interval = 1 + manifest_item_list = [ + { + "reference_text": args.reference_text, + "target_text": args.target_text, + "audio_filepath": args.reference_audio, + "target_audio_path": "test", + } + ] + elif args.huggingface_dataset: + import datasets + + dataset = datasets.load_dataset( + args.huggingface_dataset, + split=args.split_name, + trust_remote_code=True, + ) + manifest_item_list = [] + for i in range(len(dataset)): + manifest_item_list.append( + { + "audio_filepath": dataset[i]["prompt_audio"], + "reference_text": dataset[i]["prompt_text"], + "target_audio_path": dataset[i]["id"], + "target_text": dataset[i]["target_text"], + } + ) + else: + manifest_item_list = load_manifests(args.manifest_path) + + args.num_tasks = min(args.num_tasks, len(manifest_item_list)) + manifest_item_list = split_data(manifest_item_list, args.num_tasks) + + os.makedirs(args.log_dir, exist_ok=True) + tasks = [] + start_time = time.time() + for i in range(args.num_tasks): + task = asyncio.create_task( + send( + manifest_item_list[i], + name=f"task-{i}", + triton_client=triton_client, + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=1, + save_sample_rate=24000 if args.model_name == "f5_tts" else 16000, + ) + ) + tasks.append(task) + + ans_list = await asyncio.gather(*tasks) + + end_time = time.time() + elapsed = end_time - start_time + + total_duration = 0.0 + latency_data = [] + for ans in ans_list: + total_duration += ans[0] + latency_data += ans[1] + + rtf = elapsed / total_duration + + s = f"RTF: {rtf:.4f}\n" + s += f"total_duration: {total_duration:.3f} seconds\n" + s += f"({total_duration / 3600:.2f} hours)\n" + s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" + + latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] + latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 + latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 + s += f"latency_variance: {latency_variance:.2f}\n" + s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" + s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" + s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" + s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" + s += f"average_latency_ms: {latency_ms:.2f}\n" + + print(s) + if args.manifest_path: + name = Path(args.manifest_path).stem + elif args.split_name: + name = args.split_name + with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: + f.write(s) + + stats = await triton_client.get_inference_statistics(model_name="", as_json=True) + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True) + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) + + +if 
__name__ == "__main__": + asyncio.run(main()) diff --git a/src/f5_tts/runtime/triton_trtllm/client_http.py b/src/f5_tts/runtime/triton_trtllm/client_http.py new file mode 100644 index 0000000..87b83d5 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/client_http.py @@ -0,0 +1,142 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import requests +import soundfile as sf +import numpy as np +import argparse + + +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--server-url", + type=str, + default="localhost:8000", + help="Address of the server", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default="../../infer/examples/basic/basic_ref_en.wav", + help="Path to a single audio file. It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="Some call me nature, others call me mother nature.", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. 
But always remember, I am mighty and enduring.", + help="", + ) + + parser.add_argument( + "--model-name", + type=str, + default="f5_tts", + choices=["f5_tts", "spark_tts"], + help="triton model_repo module name to request", + ) + + parser.add_argument( + "--output-audio", + type=str, + default="output.wav", + help="Path to save the output audio", + ) + return parser.parse_args() + + +def prepare_request( + samples, + reference_text, + target_text, + sample_rate=16000, + audio_save_dir: str = "./", +): + assert len(samples.shape) == 1, "samples should be 1D" + lengths = np.array([[len(samples)]], dtype=np.int32) + samples = samples.reshape(1, -1).astype(np.float32) + + data = { + "inputs": [ + {"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()}, + { + "name": "reference_wav_len", + "shape": lengths.shape, + "datatype": "INT32", + "data": lengths.tolist(), + }, + {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]}, + {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]}, + ] + } + + return data + + +def load_audio(wav_path, target_sample_rate=16000): + assert target_sample_rate == 16000, "hard coding in server" + if isinstance(wav_path, dict): + samples = wav_path["array"] + sample_rate = wav_path["sampling_rate"] + else: + samples, sample_rate = sf.read(wav_path) + if sample_rate != target_sample_rate: + from scipy.signal import resample + + num_samples = int(len(samples) * (target_sample_rate / sample_rate)) + samples = resample(samples, num_samples) + return samples, target_sample_rate + + +if __name__ == "__main__": + args = get_args() + server_url = args.server_url + if not server_url.startswith(("http://", "https://")): + server_url = f"http://{server_url}" + + url = f"{server_url}/v2/models/{args.model_name}/infer" + samples, sr = load_audio(args.reference_audio) + assert sr == 16000, "sample rate hardcoded in server" + + samples = np.array(samples, dtype=np.float32) + data = prepare_request(samples, args.reference_text, args.target_text) + + rsp = requests.post( + url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"} + ) + result = rsp.json() + audio = result["outputs"][0]["data"] + audio = np.array(audio, dtype=np.float32) + sf.write(args.output_audio, audio, 24000, "PCM_16") diff --git a/src/f5_tts/runtime/triton_trtllm/docker-compose.yml b/src/f5_tts/runtime/triton_trtllm/docker-compose.yml new file mode 100644 index 0000000..1519591 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/docker-compose.yml @@ -0,0 +1,20 @@ +services: + tts: + image: soar97/triton-f5-tts:24.12 + shm_size: '1gb' + ports: + - "8000:8000" + - "8001:8001" + - "8002:8002" + environment: + - PYTHONIOENCODING=utf-8 + - MODEL_ID=${MODEL_ID} + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0'] + capabilities: [gpu] + command: > + /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL" diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py new file mode 100644 index 0000000..ecd12a6 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py @@ -0,0 +1,431 @@ +import tensorrt as trt +import os +import math +import time +from typing import List, 
Optional +from functools import wraps + +import tensorrt_llm +from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.runtime.session import Session + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def remove_tensor_padding(input_tensor, input_tensor_lengths=None): + # Audio tensor case: batch, seq_len, feature_len + # position_ids case: batch, seq_len + assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor" + + # Initialize a list to collect valid sequences + valid_sequences = [] + + for i in range(input_tensor.shape[0]): + valid_length = input_tensor_lengths[i] + valid_sequences.append(input_tensor[i, :valid_length]) + + # Concatenate all valid sequences along the batch dimension + output_tensor = torch.cat(valid_sequences, dim=0).contiguous() + return output_tensor + + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) + + def forward(self, text): + # only keep tensors with value not -1 + text_mask = text != -1 + text_pad_cut_off_index = text_mask.sum(dim=1).max() + + text = text[:, :text_pad_cut_off_index] + text = self.text_embed(text) + text = text + self.freqs_cis[: text.shape[1], :] + for block in self.text_blocks: + text = block(text) + # padding text to the original length + # text shape: B,seq_len,C + # pad at the second dimension + text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0) + return text + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # 
https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def load_checkpoint(ckpt_path, use_ema=True): + checkpoint = torch.load(ckpt_path, weights_only=True) + if use_ema: + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + dict_state = checkpoint["model_state_dict"] + text_embed_dict = {} + for key in dict_state.keys(): + # transformer.text_embed.text_embed.weight -> text_embed.weight + if "text_embed" in key: + text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key] + return text_embed_dict + + +class F5TTS(object): + def __init__( + self, + config, + debug_mode=True, + stream: Optional[torch.cuda.Stream] = None, + tllm_model_dir: Optional[str] = None, + model_path: Optional[str] = None, + vocab_size: Optional[int] = None, + ): + self.dtype = config["pretrained_config"]["dtype"] + + rank = tensorrt_llm.mpi_rank() + world_size = config["pretrained_config"]["mapping"]["world_size"] + cp_size = config["pretrained_config"]["mapping"]["cp_size"] + tp_size = config["pretrained_config"]["mapping"]["tp_size"] + pp_size = config["pretrained_config"]["mapping"]["pp_size"] + assert pp_size == 1 + self.mapping = tensorrt_llm.Mapping( + world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1 + ) + + local_rank = rank % self.mapping.gpus_per_node + self.device = torch.device(f"cuda:{local_rank}") + + torch.cuda.set_device(self.device) + + self.stream = stream + if self.stream is None: + self.stream = torch.cuda.Stream(self.device) + torch.cuda.set_stream(self.stream) + + engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine") + logger.info(f"Loading engine from {engine_file}") + with open(engine_file, "rb") as f: + engine_buffer = f.read() + + assert engine_buffer is not None + + self.session = Session.from_serialized_engine(engine_buffer) + + self.debug_mode = debug_mode + + self.inputs = {} + self.outputs = {} + self.buffer_allocated = False + + expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"] + + found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)] + if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names): + logger.error( + f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}" + ) + logger.error( + f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}" + ) + logger.error(f"Expected tensor names: {expected_tensor_names}") + logger.error(f"Found tensor names: {found_tensor_names}") + raise RuntimeError("Tensor names in engine are not the same as expected.") + if self.debug_mode: + self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names)) + + self.max_mel_len = 4096 + self.text_embedding = TextEmbedding( + text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len + 
).to(self.device) + self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True) + + self.target_audio_sample_rate = 24000 + self.target_rms = 0.15 # target rms for audio + self.n_fft = 1024 + self.win_length = 1024 + self.hop_length = 256 + self.n_mel_channels = 100 + # self.max_mel_len = 3000 + self.head_dim = 64 + self.base_rescale_factor = 1.0 + self.interpolation_factor = 1.0 + base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)) + freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor + self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0) + self.rope_cos = self.freqs.cos().half() + self.rope_sin = self.freqs.sin().half() + self.nfe_steps = 16 + t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32) + time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t) + delta_t = torch.diff(time_step) + # WAR: hard coding 256 here + tmp_dim = 256 + time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32) + half_dim = tmp_dim // 2 + emb_factor = math.log(10000) / (half_dim - 1) + emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor) + for i in range(self.nfe_steps): + emb = time_step[i] * emb_factor + time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1) + self.time_expand = time_expand.to(self.device) + self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device) + + def _tensor_dtype(self, name): + # return torch dtype given tensor name for convenience + dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name)) + return dtype + + def _setup(self, batch_size, seq_len): + for i in range(self.session.engine.num_io_tensors): + name = self.session.engine.get_tensor_name(i) + if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: + shape = list(self.session.engine.get_tensor_shape(name)) + shape[0] = batch_size + shape[1] = seq_len + self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device) + + self.buffer_allocated = True + + def cuda_stream_guard(func): + """Sync external stream and set current stream to the one bound to the session. 
Reset on exit.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + external_stream = torch.cuda.current_stream() + if external_stream != self.stream: + external_stream.synchronize() + torch.cuda.set_stream(self.stream) + ret = func(self, *args, **kwargs) + if external_stream != self.stream: + self.stream.synchronize() + torch.cuda.set_stream(external_stream) + return ret + + return wrapper + + @cuda_stream_guard + def forward( + self, + noise: torch.Tensor, + cond: torch.Tensor, + time_expand: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + input_lengths: torch.Tensor, + delta_t: torch.Tensor, + use_perf: bool = False, + ): + if use_perf: + torch.cuda.nvtx.range_push("flow matching") + cfg_strength = 2.0 + batch_size = noise.shape[0] + half_batch = batch_size // 2 + noise_half = noise[:half_batch] # Store the initial half of noise + + input_type = str_dtype_to_torch(self.dtype) + + # Keep a copy of the initial tensors + cond = cond.to(input_type) + rope_cos = rope_cos.to(input_type) + rope_sin = rope_sin.to(input_type) + input_lengths = input_lengths.to(str_dtype_to_torch("int32")) + + # Instead of iteratively updating noise within a single model context, + # we'll do a single forward pass for each iteration with fresh context setup + for i in range(self.nfe_steps): + # Re-setup the buffers for clean execution + self._setup(batch_size, noise.shape[1]) + if not self.buffer_allocated: + raise RuntimeError("Buffer not allocated, please call setup first!") + + # Re-create combined noises for this iteration + current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type) + + # Get time step for this iteration + current_time = time_expand[:, i].to(input_type) + + # Create fresh input dictionary for this iteration + current_inputs = { + "noise": current_noise, + "cond": cond, + "time": current_time, + "rope_cos": rope_cos, + "rope_sin": rope_sin, + "input_lengths": input_lengths, + } + + # Update inputs and set shapes + self.inputs.clear() # Clear previous inputs + self.inputs.update(**current_inputs) + self.session.set_shapes(self.inputs) + + if use_perf: + torch.cuda.nvtx.range_push(f"execute {i}") + ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream) + assert ok, "Failed to execute model" + # self.session.context.execute_async_v3(self.stream.cuda_stream) + if use_perf: + torch.cuda.nvtx.range_pop() + # Process results + t_scale = delta_t[i].unsqueeze(0).to(input_type) + + # Extract predictions + pred_cond = self.outputs["denoised"][:half_batch] + pred_uncond = self.outputs["denoised"][half_batch:] + + # Apply classifier-free guidance with safeguards + guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength + # Calculate update for noise + noise_half = noise_half + guidance * t_scale + if use_perf: + torch.cuda.nvtx.range_pop() + return noise_half + + def sample( + self, + text_pad_sequence: torch.Tensor, + ref_mel_batch: torch.Tensor, + ref_mel_len_batch: torch.Tensor, + estimated_reference_target_mel_len: List[int], + remove_input_padding: bool = False, + use_perf: bool = False, + ): + if use_perf: + torch.cuda.nvtx.range_push("text embedding") + batch = text_pad_sequence.shape[0] + max_seq_len = ref_mel_batch.shape[1] + + text_pad_sequence_drop = torch.cat( + (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0 + ) + + text_embedding_drop_list = [] + for i in range(batch + 1): + 
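+            # the last row (i == batch) embeds the all-zero "dropped text" sequence appended above; it is reused below as the unconditional (CFG) branch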
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device))) + text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0) + + text_embedding = text_embedding_drop_condition[:-1] + # text_embedding_drop B,T,C batch should be the same + text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1) + + noise = torch.randn_like(ref_mel_batch).to(self.device) + rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) + rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1) + + cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1) + cat_mel_text_drop = torch.cat( + ( + torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), + text_embedding_drop, + ), + dim=-1, + ) + + time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous() + + # Convert estimated_reference_target_mel_len to tensor + input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32) + + # combine above along the batch dimension + inputs = { + "noise": torch.cat((noise, noise), dim=0).contiguous(), + "cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(), + "time_expand": time_expand, + "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(), + "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(), + "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(), + "delta_t": self.delta_t, + } + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_push("remove input padding") + if remove_input_padding: + max_seq_len = inputs["cond"].shape[1] + inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"]) + inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"]) + # for time_expand, convert from B,D to B,T,D by repeat + inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1) + inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"]) + inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"]) + inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"]) + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_pop() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + if use_perf: + torch.cuda.nvtx.range_pop() + start_time = time.time() + denoised = self.forward(**inputs, use_perf=use_perf) + cost_time = time.time() - start_time + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_push("remove input padding output") + if remove_input_padding: + denoised_list = [] + start_idx = 0 + for i in range(batch): + denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]]) + start_idx += inputs["input_lengths"][i] + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_pop() + return denoised_list, cost_time + return denoised, cost_time diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py new file mode 100644 index 0000000..9265886 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py @@ -0,0 +1,275 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json +import torch +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +from torch.utils.dlpack import from_dlpack, to_dlpack +import torchaudio +import jieba +import triton_python_backend_utils as pb_utils +from pypinyin import Style, lazy_pinyin +import os +from f5_tts_trtllm import F5TTS + + +def get_tokenizer(vocab_file_path: str): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + with open(vocab_file_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + return vocab_char_map, vocab_size + + +def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): + final_reference_target_texts_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return "\u3100" <= c <= "\u9fff" # common chinese characters + + for text in reference_target_texts_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + char_list.append(seg_[i]) + else: # if mixed characters, alphabets 
and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) + else: + char_list.append(c) + final_reference_target_texts_list.append(char_list) + + return final_reference_target_texts_list + + +def list_str_to_idx( + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +): # noqa: F722 + list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style + return list_idx_tensors + + +class TritonPythonModel: + def initialize(self, args): + self.use_perf = True + self.device = torch.device("cuda") + self.target_audio_sample_rate = 24000 + self.target_rms = 0.15 # target rms for audio + self.n_fft = 1024 + self.win_length = 1024 + self.hop_length = 256 + self.n_mel_channels = 100 + self.max_mel_len = 3000 + self.head_dim = 64 + + parameters = json.loads(args["model_config"])["parameters"] + for key, value in parameters.items(): + parameters[key] = value["string_value"] + + self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"]) + self.reference_sample_rate = int(parameters["reference_audio_sample_rate"]) + self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate) + + self.tllm_model_dir = parameters["tllm_model_dir"] + config_file = os.path.join(self.tllm_model_dir, "config.json") + with open(config_file) as f: + config = json.load(f) + self.model = F5TTS( + config, + debug_mode=False, + tllm_model_dir=self.tllm_model_dir, + model_path=parameters["model_path"], + vocab_size=self.vocab_size, + ) + + self.vocoder = parameters["vocoder"] + assert self.vocoder in ["vocos", "bigvgan"] + if self.vocoder == "vocos": + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_audio_sample_rate, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + n_mels=self.n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(self.device) + self.compute_mel_fn = self.get_vocos_mel_spectrogram + elif self.vocoder == "bigvgan": + self.compute_mel_fn = self.get_bigvgan_mel_spectrogram + + def get_vocos_mel_spectrogram(self, waveform): + mel = self.mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel.transpose(1, 2) + + def forward_vocoder(self, mel): + mel = mel.to(torch.float32).contiguous().cpu() + input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)) + + inference_request = pb_utils.InferenceRequest( + model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0] + ) + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + else: + waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform") + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def execute(self, requests): + ( + reference_text_list, + target_text_list, + reference_target_texts_list, + estimated_reference_target_mel_len, + reference_mel_len, + ) = [], [], [], [], [] + mel_features_list = [] + if self.use_perf: + torch.cuda.nvtx.range_push("preprocess") + for request in requests: + wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav") + wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + + reference_text = 
pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode("utf-8") + reference_text_list.append(reference_text) + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode("utf-8") + target_text_list.append(target_text) + + text = reference_text + target_text + reference_target_texts_list.append(text) + + wav = from_dlpack(wav_tensor.to_dlpack()) + wav_len = from_dlpack(wav_lens.to_dlpack()) + wav_len = wav_len.squeeze() + assert wav.shape[0] == 1, "Only support batch size 1 for now." + wav = wav[:, :wav_len] + + ref_rms = torch.sqrt(torch.mean(torch.square(wav))) + if ref_rms < self.target_rms: + wav = wav * self.target_rms / ref_rms + if self.reference_sample_rate != self.target_audio_sample_rate: + wav = self.resampler(wav) + wav = wav.to(self.device) + if self.use_perf: + torch.cuda.nvtx.range_push("compute_mel") + mel_features = self.compute_mel_fn(wav) + if self.use_perf: + torch.cuda.nvtx.range_pop() + mel_features_list.append(mel_features) + + reference_mel_len.append(mel_features.shape[1]) + estimated_reference_target_mel_len.append( + int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text))) + ) + + max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len) + + batch = len(requests) + mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device) + for i, mel in enumerate(mel_features_list): + mel_features[i, : mel.shape[1], :] = mel + + reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device) + + pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) + text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map) + + for i, item in enumerate(text_pad_sequence): + text_pad_sequence[i] = F.pad( + item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1 + ) + text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS + text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device) + text_pad_sequence = F.pad( + text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1 + ) + if self.use_perf: + torch.cuda.nvtx.range_pop() + + denoised, cost_time = self.model.sample( + text_pad_sequence, + mel_features, + reference_mel_len_tensor, + estimated_reference_target_mel_len, + remove_input_padding=False, + use_perf=self.use_perf, + ) + if self.use_perf: + torch.cuda.nvtx.range_push("vocoder") + + responses = [] + for i in range(batch): + ref_me_len = reference_mel_len[i] + estimated_mel_len = estimated_reference_target_mel_len[i] + denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2) + audio = self.forward_vocoder(denoised_one_item) + rms = torch.sqrt(torch.mean(torch.square(audio))) + if rms < self.target_rms: + audio = audio * self.target_rms / rms + + audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio]) + responses.append(inference_response) + if self.use_perf: + torch.cuda.nvtx.range_pop() + return responses diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt new file mode 100644 index 0000000..4663f7c --- /dev/null +++ 
b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "f5_tts" +backend: "python" +max_batch_size: 4 +dynamic_batching { + max_queue_delay_microseconds: 1000 +} +parameters [ + { + key: "vocab_file" + value: { string_value: "${vocab}"} + }, + { + key: "model_path", + value: {string_value:"${model}"} + }, + { + key: "tllm_model_dir", + value: {string_value:"${trtllm}"} + }, + { + key: "reference_audio_sample_rate", + value: {string_value:"16000"} + }, + { + key: "vocoder", + value: {string_value:"${vocoder}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + optional: True + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + optional: True + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt new file mode 100644 index 0000000..9a30b52 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt @@ -0,0 +1,32 @@ +name: "vocoder" +backend: "tensorrt" +default_model_filename: "vocoder.plan" +max_batch_size: 4 + +input [ + { + name: "mel" + data_type: TYPE_FP32 + dims: [ 100, -1 ] + } +] + +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +dynamic_batching { + preferred_batch_size: [1, 2, 4] + max_queue_delay_microseconds: 1 +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/patch/__init__.py b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py new file mode 100644 index 0000000..d43cacc --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from .baichuan.model import BaichuanForCausalLM +from .bert.model import (BertForQuestionAnswering, + BertForSequenceClassification, BertModel, + RobertaForQuestionAnswering, + RobertaForSequenceClassification, RobertaModel) +from .bloom.model import BloomForCausalLM, BloomModel +from .chatglm.config import ChatGLMConfig +from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel +from .cogvlm.config import CogVLMConfig +from .cogvlm.model import CogVLMForCausalLM +from .commandr.model import CohereForCausalLM +from .dbrx.config import DbrxConfig +from .dbrx.model import DbrxForCausalLM +from .deepseek_v1.model import DeepseekForCausalLM +from .deepseek_v2.model import DeepseekV2ForCausalLM +from .dit.model import DiT +from .eagle.model import EagleForCausalLM +from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder +from .falcon.config import FalconConfig +from .falcon.model import FalconForCausalLM, FalconModel +from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig +from .gemma.model import GemmaForCausalLM +from .gpt.config import GPTConfig +from .gpt.model import GPTForCausalLM, GPTModel +from .gptj.config import GPTJConfig +from .gptj.model import GPTJForCausalLM, GPTJModel +from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel +from .grok.model import GrokForCausalLM +from .llama.config import LLaMAConfig +from .llama.model import LLaMAForCausalLM, LLaMAModel +from .mamba.model import MambaForCausalLM +from .medusa.config import MedusaConfig +from .medusa.model import MedusaForCausalLm +from .mllama.model import MLLaMAModel +from .modeling_utils import (PretrainedConfig, PretrainedModel, + SpeculativeDecodingMode) +from .mpt.model import MPTForCausalLM, MPTModel +from .nemotron_nas.model import DeciLMForCausalLM +from .opt.model import OPTForCausalLM, OPTModel +from .phi3.model import Phi3ForCausalLM, Phi3Model +from .phi.model import PhiForCausalLM, PhiModel +from .qwen.model import QWenForCausalLM +from .recurrentgemma.model import RecurrentGemmaForCausalLM +from .redrafter.model import ReDrafterForCausalLM +from .f5tts.model import F5TTS + +__all__ = [ + 'BertModel', + 'BertForQuestionAnswering', + 'BertForSequenceClassification', + 'RobertaModel', + 'RobertaForQuestionAnswering', + 'RobertaForSequenceClassification', + 'BloomModel', + 'BloomForCausalLM', + 'DiT', + 'DeepseekForCausalLM', + 'FalconConfig', + 'DeepseekV2ForCausalLM', + 'FalconForCausalLM', + 'FalconModel', + 'GPTConfig', + 'GPTModel', + 'GPTForCausalLM', + 'OPTForCausalLM', + 'OPTModel', + 'LLaMAConfig', + 'LLaMAForCausalLM', + 'LLaMAModel', + 'MedusaConfig', + 'MedusaForCausalLm', + 'ReDrafterForCausalLM', + 'GPTJConfig', + 'GPTJModel', + 'GPTJForCausalLM', + 'GPTNeoXModel', + 'GPTNeoXForCausalLM', + 'PhiModel', + 'PhiConfig', + 'Phi3Model', + 'Phi3Config', + 'PhiForCausalLM', + 'Phi3ForCausalLM', + 'ChatGLMConfig', + 'ChatGLMForCausalLM', + 'ChatGLMModel', + 'BaichuanForCausalLM', + 'QWenConfig' + 'QWenForCausalLM', + 'QWenModel', + 'EncoderModel', + 'DecoderModel', + 'PretrainedConfig', + 'PretrainedModel', + 'WhisperEncoder', + 'MambaForCausalLM', + 'MambaConfig', + 'MPTForCausalLM', + 'MPTModel', + 'SkyworkForCausalLM', + 'GemmaConfig', + 'GemmaForCausalLM', + 'DbrxConfig', + 'DbrxForCausalLM', + 'RecurrentGemmaForCausalLM', + 'CogVLMConfig', + 'CogVLMForCausalLM', + 'EagleForCausalLM', + 'SpeculativeDecodingMode', + 'CohereForCausalLM', + 'MLLaMAModel', + 
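+    # 'F5TTS' is the entry added by this patch on top of the stock TensorRT-LLM model registry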
'F5TTS', +] + +MODEL_MAP = { + 'GPT2LMHeadModel': GPTForCausalLM, + 'GPT2LMHeadCustomModel': GPTForCausalLM, + 'GPTBigCodeForCausalLM': GPTForCausalLM, + 'Starcoder2ForCausalLM': GPTForCausalLM, + 'FuyuForCausalLM': GPTForCausalLM, + 'Kosmos2ForConditionalGeneration': GPTForCausalLM, + 'JAISLMHeadModel': GPTForCausalLM, + 'GPTForCausalLM': GPTForCausalLM, + 'NemotronForCausalLM': GPTForCausalLM, + 'OPTForCausalLM': OPTForCausalLM, + 'BloomForCausalLM': BloomForCausalLM, + 'RWForCausalLM': FalconForCausalLM, + 'FalconForCausalLM': FalconForCausalLM, + 'PhiForCausalLM': PhiForCausalLM, + 'Phi3ForCausalLM': Phi3ForCausalLM, + 'Phi3VForCausalLM': Phi3ForCausalLM, + 'Phi3SmallForCausalLM': Phi3ForCausalLM, + 'PhiMoEForCausalLM': Phi3ForCausalLM, + 'MambaForCausalLM': MambaForCausalLM, + 'GPTNeoXForCausalLM': GPTNeoXForCausalLM, + 'GPTJForCausalLM': GPTJForCausalLM, + 'MPTForCausalLM': MPTForCausalLM, + 'GLMModel': ChatGLMForCausalLM, + 'ChatGLMModel': ChatGLMForCausalLM, + 'ChatGLMForCausalLM': ChatGLMForCausalLM, + 'LlamaForCausalLM': LLaMAForCausalLM, + 'ExaoneForCausalLM': LLaMAForCausalLM, + 'MistralForCausalLM': LLaMAForCausalLM, + 'MixtralForCausalLM': LLaMAForCausalLM, + 'ArcticForCausalLM': LLaMAForCausalLM, + 'Grok1ModelForCausalLM': GrokForCausalLM, + 'InternLMForCausalLM': LLaMAForCausalLM, + 'InternLM2ForCausalLM': LLaMAForCausalLM, + 'MedusaForCausalLM': MedusaForCausalLm, + 'ReDrafterForCausalLM': ReDrafterForCausalLM, + 'BaichuanForCausalLM': BaichuanForCausalLM, + 'BaiChuanForCausalLM': BaichuanForCausalLM, + 'SkyworkForCausalLM': LLaMAForCausalLM, + GEMMA_ARCHITECTURE: GemmaForCausalLM, + GEMMA2_ARCHITECTURE: GemmaForCausalLM, + 'QWenLMHeadModel': QWenForCausalLM, + 'QWenForCausalLM': QWenForCausalLM, + 'Qwen2ForCausalLM': QWenForCausalLM, + 'Qwen2MoeForCausalLM': QWenForCausalLM, + 'Qwen2ForSequenceClassification': QWenForCausalLM, + 'Qwen2VLForConditionalGeneration': QWenForCausalLM, + 'WhisperEncoder': WhisperEncoder, + 'EncoderModel': EncoderModel, + 'DecoderModel': DecoderModel, + 'DbrxForCausalLM': DbrxForCausalLM, + 'RecurrentGemmaForCausalLM': RecurrentGemmaForCausalLM, + 'CogVLMForCausalLM': CogVLMForCausalLM, + 'DiT': DiT, + 'DeepseekForCausalLM': DeepseekForCausalLM, + 'DeciLMForCausalLM': DeciLMForCausalLM, + 'DeepseekV2ForCausalLM': DeepseekV2ForCausalLM, + 'EagleForCausalLM': EagleForCausalLM, + 'CohereForCausalLM': CohereForCausalLM, + 'MllamaForConditionalGeneration': MLLaMAModel, + 'BertForQuestionAnswering': BertForQuestionAnswering, + 'BertForSequenceClassification': BertForSequenceClassification, + 'BertModel': BertModel, + 'RobertaModel': RobertaModel, + 'RobertaForQuestionAnswering': RobertaForQuestionAnswering, + 'RobertaForSequenceClassification': RobertaForSequenceClassification, + 'F5TTS': F5TTS +} diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py new file mode 100644 index 0000000..26c8bc9 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py @@ -0,0 +1,225 @@ +from __future__ import annotations +import sys +import os + +import tensorrt as trt +from collections import OrderedDict +from ..._utils import str_dtype_to_trt +from ...plugin import current_all_reduce_helper +from ..modeling_utils import PretrainedConfig, PretrainedModel +from ...functional import Tensor, concat +from ...module import Module, ModuleList +from tensorrt_llm._common import default_net +from ...layers import Linear + +from .modules import ( + TimestepEmbedding, + ConvPositionEmbedding, + 
DiTBlock, + AdaLayerNormZero_Final, +) + +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(current_file_path) +sys.path.append(parent_dir) + + +class InputEmbedding(Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) + + def forward(self, x, cond): + x = self.proj(concat([x, cond], dim=-1)) + return self.conv_pos_embed(x) + x + + +class F5TTS(PretrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.dtype = str_dtype_to_trt(config.dtype) + + self.time_embed = TimestepEmbedding(config.hidden_size) + self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size) + + self.dim = config.hidden_size + self.depth = config.num_hidden_layers + self.transformer_blocks = ModuleList( + [ + DiTBlock( + dim=self.dim, + heads=config.num_attention_heads, + dim_head=config.dim_head, + ff_mult=config.ff_mult, + dropout=config.dropout, + ) + for _ in range(self.depth) + ] + ) + + self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = Linear(config.hidden_size, config.mel_dim) + + def forward( + self, + noise, # nosied input audio + cond, # masked cond audio + time, # time step + rope_cos, + rope_sin, + input_lengths, + scale=1.0, + ): + t = self.time_embed(time) + x = self.input_embed(noise, cond) + for block in self.transformer_blocks: + x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) + denoise = self.proj_out(self.norm_out(x, t)) + denoise.mark_output("denoised", self.dtype) + return denoise + + def prepare_inputs(self, **kwargs): + max_batch_size = kwargs["max_batch_size"] + batch_size_range = [2, 2, max_batch_size] + mel_size = 100 + max_seq_len = 3000 + num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size] + hidden_size = 512 + concat_feature_dim = mel_size + hidden_size + freq_embed_dim = 256 + head_dim = 64 + mapping = self.config.mapping + if mapping.tp_size > 1: + current_all_reduce_helper().set_workspace_tensor(mapping, 1) + if default_net().plugin_config.remove_input_padding: + noise = Tensor( + name="noise", + dtype=self.dtype, + shape=[-1, mel_size], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("n_mels", [mel_size]), + ] + ), + ) + cond = Tensor( + name="cond", + dtype=self.dtype, + shape=[-1, concat_feature_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("embeded_length", [concat_feature_dim]), + ] + ), + ) + time = Tensor( + name="time", + dtype=self.dtype, + shape=[-1, freq_embed_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("freq_dim", [freq_embed_dim]), + ] + ), + ) + rope_cos = Tensor( + name="rope_cos", + dtype=self.dtype, + shape=[-1, head_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("head_dim", [head_dim]), + ] + ), + ) + rope_sin = Tensor( + name="rope_sin", + dtype=self.dtype, + shape=[-1, head_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("head_dim", [head_dim]), + ] + ), + ) + + else: + noise = Tensor( + name="noise", + dtype=self.dtype, + shape=[-1, -1, mel_size], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("n_mels", [mel_size]), + ] + ), + ) + cond = Tensor( + name="cond", + dtype=self.dtype, + shape=[-1, -1, 
concat_feature_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("embeded_length", [concat_feature_dim]), + ] + ), + ) + time = Tensor( + name="time", + dtype=self.dtype, + shape=[-1, freq_embed_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("freq_dim", [freq_embed_dim]), + ] + ), + ) + rope_cos = Tensor( + name="rope_cos", + dtype=self.dtype, + shape=[-1, -1, head_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("head_dim", [head_dim]), + ] + ), + ) + rope_sin = Tensor( + name="rope_sin", + dtype=self.dtype, + shape=[-1, -1, head_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("head_dim", [head_dim]), + ] + ), + ) + input_lengths = Tensor( + name="input_lengths", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size", [batch_size_range])]), + ) + return { + "noise": noise, + "cond": cond, + "time": time, + "rope_cos": rope_cos, + "rope_sin": rope_sin, + "input_lengths": input_lengths, + } diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py new file mode 100644 index 0000000..a0051b4 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +import numpy as np +from tensorrt_llm._common import default_net +from ..._utils import trt_dtype_to_np, str_dtype_to_trt +from ...functional import ( + Tensor, + chunk, + concat, + constant, + expand, + shape, + silu, + slice, + permute, + expand_mask, + expand_dims_like, + unsqueeze, + matmul, + softmax, + squeeze, + cast, + gelu, +) +from ...functional import expand_dims, view, bert_attention +from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear +from ...module import Module + + +class FeedForward(Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.project_in = Linear(dim, inner_dim) + self.ff = Linear(inner_dim, dim_out) + + def forward(self, x): + return self.ff(gelu(self.project_in(x))) + + +class AdaLayerNormZero(Module): + def __init__(self, dim): + super().__init__() + + self.linear = Linear(dim, dim * 6) + self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1) + x = self.norm(x) + ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) + if default_net().plugin_config.remove_input_padding: + x = x * (ones + scale_msa) + shift_msa + else: + x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1) + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZero_Final(Module): + def __init__(self, dim): + super().__init__() + + self.linear = Linear(dim, dim * 2) + + self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(silu(emb)) + scale, shift = chunk(emb, 2, dim=1) + ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) + if default_net().plugin_config.remove_input_padding: + x = self.norm(x) * 
(ones + scale) + shift + else: + x = self.norm(x) * unsqueeze((ones + scale), 1) + x = x + unsqueeze(shift, 1) + return x + + +class ConvPositionEmbedding(Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) + self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) + self.mish = Mish() + + def forward(self, x, mask=None): # noqa: F722 + if default_net().plugin_config.remove_input_padding: + x = unsqueeze(x, 0) + x = permute(x, [0, 2, 1]) + x = self.mish(self.conv1d2(self.mish(self.conv1d1(x)))) + out = permute(x, [0, 2, 1]) + if default_net().plugin_config.remove_input_padding: + out = squeeze(out, 0) + return out + + +class Attention(Module): + def __init__( + self, + processor: AttnProcessor, + dim: int, + heads: int = 16, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.processor = processor + + self.dim = dim # hidden_size + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + self.attention_head_size = dim_head + self.context_dim = context_dim + self.context_pre_only = context_pre_only + self.tp_size = 1 + self.num_attention_heads = heads // self.tp_size + self.num_attention_kv_heads = heads // self.tp_size # 8 + self.dtype = str_dtype_to_trt("float32") + self.attention_hidden_size = self.attention_head_size * self.num_attention_heads + self.to_q = ColumnLinear( + dim, + self.tp_size * self.num_attention_heads * self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) + self.to_k = ColumnLinear( + dim, + self.tp_size * self.num_attention_heads * self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) + self.to_v = ColumnLinear( + dim, + self.tp_size * self.num_attention_heads * self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) + + if self.context_dim is not None: + self.to_k_c = Linear(context_dim, self.inner_dim) + self.to_v_c = Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = Linear(context_dim, self.inner_dim) + + self.to_out = RowLinear( + self.tp_size * self.num_attention_heads * self.attention_head_size, + dim, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = Linear(self.inner_dim, dim) + + def forward( + self, + x, # noised input x + rope_cos, + rope_sin, + input_lengths, + c=None, # context c + scale=1.0, + rope=None, + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope) + else: + return self.processor( + self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale + ) + + +def rotate_every_two_3dim(tensor: Tensor) -> Tensor: + shape_tensor = concat( + [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())] + ) + if default_net().plugin_config.remove_input_padding: + assert 
tensor.ndim() == 2 + x1 = slice(tensor, [0, 0], shape_tensor, [1, 2]) + x2 = slice(tensor, [0, 1], shape_tensor, [1, 2]) + x1 = expand_dims(x1, 2) + x2 = expand_dims(x2, 2) + zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + x2 = zero - x2 + x = concat([x2, x1], 2) + out = view(x, concat([shape(x, 0), shape(x, 1) * 2])) + else: + assert tensor.ndim() == 3 + + x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2]) + x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2]) + x1 = expand_dims(x1, 3) + x2 = expand_dims(x2, 3) + zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + x2 = zero - x2 + x = concat([x2, x1], 3) + out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2])) + + return out + + +def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin): + if default_net().plugin_config.remove_input_padding: + rot_dim = shape(rope_cos, -1) # 64 + new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64) + x_ = slice(x, [0, 0], new_t_shape, [1, 1]) + end_dim = shape(x, -1) - shape(rope_cos, -1) + new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960) + x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1]) + out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1) + else: + rot_dim = shape(rope_cos, 2) # 64 + new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64) + x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1]) + end_dim = shape(x, 2) - shape(rope_cos, 2) + new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960) + x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1]) + out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1) + return out + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn, + x, # noised input x + rope_cos, + rope_sin, + input_lengths, + scale=1.0, + rope=None, + ) -> torch.FloatTensor: + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + # k,v,q all (2,1226,1024) + query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin) + key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin) + + # attention + inner_dim = key.shape[-1] + norm_factor = math.sqrt(attn.attention_head_size) + q_scaling = 1.0 / norm_factor + mask = None + if not default_net().plugin_config.remove_input_padding: + N = shape(x, 1) + B = shape(x, 0) + seq_len_2d = concat([1, N]) + max_position_embeddings = 4096 + # create position ids + position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1 + tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = mask.cast("int32") + + if default_net().plugin_config.bert_attention_plugin: + qkv = concat([query, key, value], dim=-1) + # TRT plugin mode + assert input_lengths is not None + if default_net().plugin_config.remove_input_padding: + qkv = qkv.view(concat([-1, 3 * inner_dim])) + max_input_length = constant( + np.zeros( + [ + 2048, + ], + dtype=np.int32, + ) + ) + else: + max_input_length = None + context = bert_attention( + qkv, + input_lengths, + attn.num_attention_heads, + attn.attention_head_size, + q_scaling=q_scaling, + 
max_input_length=max_input_length, + ) + else: + assert not default_net().plugin_config.remove_input_padding + + def transpose_for_scores(x): + new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) + + y = x.view(new_x_shape) + y = y.transpose(1, 2) + return y + + def transpose_for_scores_k(x): + new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) + + y = x.view(new_x_shape) + y = y.permute([0, 2, 3, 1]) + return y + + query = transpose_for_scores(query) + key = transpose_for_scores_k(key) + value = transpose_for_scores(value) + + attention_scores = matmul(query, key, use_fp32_acc=False) + + if mask is not None: + attention_mask = expand_mask(mask, shape(query, 2)) + attention_mask = cast(attention_mask, attention_scores.dtype) + attention_scores = attention_scores + attention_mask + + attention_probs = softmax(attention_scores, dim=-1) + + context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2) + context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size])) + context = attn.to_out(context) + if mask is not None: + mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1])) + mask = expand_dims_like(mask, context) + mask = cast(mask, context.dtype) + context = context * mask + return context + + +# DiT Block +class DiTBlock(Module): + def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout) + + def forward( + self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + # attention + # norm ----> (2,1226,1024) + attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) + + # process attention output for input x + if default_net().plugin_config.remove_input_padding: + x = x + gate_msa * attn_output + else: + x = x + unsqueeze(gate_msa, 1) * attn_output + ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) + if default_net().plugin_config.remove_input_padding: + norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp + else: + norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1) + # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp + ff_output = self.ff(norm) + if default_net().plugin_config.remove_input_padding: + x = x + gate_mlp * ff_output + else: + x = x + unsqueeze(gate_mlp, 1) * ff_output + + return x + + +class TimestepEmbedding(Module): + def __init__(self, dim, freq_embed_dim=256, dtype=None): + super().__init__() + # self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype) + self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype) + + def forward(self, timestep): + t_freq = self.mlp1(timestep) + t_freq = silu(t_freq) + t_emb = self.mlp2(t_freq) + return t_emb diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh new file mode 100644 index 0000000..88e0d68 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/run.sh @@ -0,0
+1,70 @@ +stage=$1 +stop_stage=$2 +model=$3 # F5TTS_Base +if [ -z "$model" ]; then + echo "Model is none" + exit 1 +fi +echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model" +export CUDA_VISIBLE_DEVICES=0 + +F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS +F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt +F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine + +vocoder_trt_engine_path=vocos_vocoder.plan +model_repo=./model_repo + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Downloading f5 tts from huggingface" + huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH + +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "Converting checkpoint" + python3 ./scripts/convert_checkpoint.py \ + --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \ + --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model + python_package_path=/usr/local/lib/python3.12/dist-packages + cp -r patch/* $python_package_path/tensorrt_llm/models + trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \ + --max_batch_size 8 \ + --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "Exporting vocos vocoder" + onnx_vocoder_path=vocos_vocoder.onnx + python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path + bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Building triton server" + rm -r $model_repo + cp -r ./model_repo_f5_tts $model_repo + python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos + cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + echo "Starting triton server" + tritonserver --model-repository=$model_repo +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Testing triton server" + num_task=1 + log_dir=./log_concurrent_tasks_${num_task} + rm -r $log_dir + python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + echo "Testing http client" + audio=../../infer/examples/basic/basic_ref_en.wav + reference_text="Some call me nature, others call me mother nature." + target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." + python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" +fi diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py new file mode 100644 index 0000000..563ba84 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py @@ -0,0 +1,247 @@ +# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License + +# Copyright (c) 2020 Shimin Zhang + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch as th +import torch.nn.functional as F +from scipy.signal import check_COLA, get_window + +support_clp_op = None +if th.__version__ >= "1.7.0": + from torch.fft import rfft as fft + + support_clp_op = True +else: + from torch import rfft as fft + + +class STFT(th.nn.Module): + def __init__( + self, + win_len=1024, + win_hop=512, + fft_len=1024, + enframe_mode="continue", + win_type="hann", + win_sqrt=False, + pad_center=True, + ): + """ + Implement of STFT using 1D convolution and 1D transpose convolutions. + Implement of framing the signal in 2 ways, `break` and `continue`. + `break` method is a kaldi-like framing. + `continue` method is a librosa-like framing. + + More information about `perfect reconstruction`: + 1. https://ww2.mathworks.cn/help/signal/ref/stft.html + 2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html + + Args: + win_len (int): Number of points in one frame. Defaults to 1024. + win_hop (int): Number of framing stride. Defaults to 512. + fft_len (int): Number of DFT points. Defaults to 1024. + enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'. + win_type (str, optional): The type of window to create. Defaults to 'hann'. + win_sqrt (bool, optional): using square root window. Defaults to True. + pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True. 
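+
+        Example (illustrative; the window sizes and the 1 s of 24 kHz noise below are arbitrary):
+            stft = STFT(win_len=1024, win_hop=256, fft_len=1024)
+            wav = th.randn(1, 24000)                       # (num_batch, num_samples)
+            spec = stft.transform(wav, return_type="complex")
+            wav_rec = stft(wav)                            # transform followed by inverse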
+ """ + super(STFT, self).__init__() + assert enframe_mode in ["break", "continue"] + assert fft_len >= win_len + self.win_len = win_len + self.win_hop = win_hop + self.fft_len = fft_len + self.mode = enframe_mode + self.win_type = win_type + self.win_sqrt = win_sqrt + self.pad_center = pad_center + self.pad_amount = self.fft_len // 2 + + en_k, fft_k, ifft_k, ola_k = self.__init_kernel__() + self.register_buffer("en_k", en_k) + self.register_buffer("fft_k", fft_k) + self.register_buffer("ifft_k", ifft_k) + self.register_buffer("ola_k", ola_k) + + def __init_kernel__(self): + """ + Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel. + ** enframe_kernel: Using conv1d layer and identity matrix. + ** fft_kernel: Using linear layer for matrix multiplication. In fact, + enframe_kernel and fft_kernel can be combined, But for the sake of + readability, I took the two apart. + ** ifft_kernel, pinv of fft_kernel. + ** overlap-add kernel, just like enframe_kernel, but transposed. + + Returns: + tuple: four kernels. + """ + enframed_kernel = th.eye(self.fft_len)[:, None, :] + if support_clp_op: + tmp = fft(th.eye(self.fft_len)) + fft_kernel = th.stack([tmp.real, tmp.imag], dim=2) + else: + fft_kernel = fft(th.eye(self.fft_len), 1) + if self.mode == "break": + enframed_kernel = th.eye(self.win_len)[:, None, :] + fft_kernel = fft_kernel[: self.win_len] + fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1) + ifft_kernel = th.pinverse(fft_kernel)[:, None, :] + window = get_window(self.win_type, self.win_len) + + self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop) + window = th.FloatTensor(window) + if self.mode == "continue": + left_pad = (self.fft_len - self.win_len) // 2 + right_pad = left_pad + (self.fft_len - self.win_len) % 2 + window = F.pad(window, (left_pad, right_pad)) + if self.win_sqrt: + self.padded_window = window + window = th.sqrt(window) + else: + self.padded_window = window**2 + + fft_kernel = fft_kernel.T * window + ifft_kernel = ifft_kernel * window + ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :] + if self.mode == "continue": + ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len] + return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel + + def is_perfect(self): + """ + Whether the parameters win_len, win_hop and win_sqrt + obey constants overlap-add(COLA) + + Returns: + bool: Return true if parameters obey COLA. + """ + return self.perfect_reconstruct and self.pad_center + + def transform(self, inputs, return_type="complex"): + """Take input data (audio) to STFT domain. + + Args: + inputs (tensor): Tensor of floats, with shape (num_batch, num_samples) + return_type (str, optional): return (mag, phase) when `magphase`, + return (real, imag) when `realimag` and complex(real, imag) when `complex`. + Defaults to 'complex'. + + Returns: + tuple: (mag, phase) when `magphase`, return (real, imag) when + `realimag`. 
Defaults to 'complex', each elements with shape + [num_batch, num_frequencies, num_frames] + """ + assert return_type in ["magphase", "realimag", "complex"] + if inputs.dim() == 2: + inputs = th.unsqueeze(inputs, 1) + self.num_samples = inputs.size(-1) + if self.pad_center: + inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect") + enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop) + outputs = th.transpose(enframe_inputs, 1, 2) + outputs = F.linear(outputs, self.fft_k) + outputs = th.transpose(outputs, 1, 2) + dim = self.fft_len // 2 + 1 + real = outputs[:, :dim, :] + imag = outputs[:, dim:, :] + if return_type == "realimag": + return real, imag + elif return_type == "complex": + assert support_clp_op + return th.complex(real, imag) + else: + mags = th.sqrt(real**2 + imag**2) + phase = th.atan2(imag, real) + return mags, phase + + def inverse(self, input1, input2=None, input_type="magphase"): + """Call the inverse STFT (iSTFT), given tensors produced + by the `transform` function. + + Args: + input1 (tensors): Magnitude/Real-part of STFT with shape + [num_batch, num_frequencies, num_frames] + input2 (tensors): Phase/Imag-part of STFT with shape + [num_batch, num_frequencies, num_frames] + input_type (str, optional): Mathematical meaning of input tensor's. + Defaults to 'magphase'. + + Returns: + tensors: Reconstructed audio given magnitude and phase. Of + shape [num_batch, num_samples] + """ + assert input_type in ["magphase", "realimag"] + if input_type == "realimag": + real, imag = None, None + if support_clp_op and th.is_complex(input1): + real, imag = input1.real, input1.imag + else: + real, imag = input1, input2 + else: + real = input1 * th.cos(input2) + imag = input1 * th.sin(input2) + inputs = th.cat([real, imag], dim=1) + outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop) + t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1)) + t = t.to(inputs.device) + coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop) + + num_frames = input1.size(-1) + num_samples = num_frames * self.win_hop + + rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples + + outputs = outputs[..., rm_start:rm_end] + coff = coff[..., rm_start:rm_end] + coffidx = th.where(coff > 1e-8) + outputs[coffidx] = outputs[coffidx] / (coff[coffidx]) + return outputs.squeeze(dim=1) + + def forward(self, inputs): + """Take input data (audio) to STFT domain and then back to audio. + + Args: + inputs (tensor): Tensor of floats, with shape [num_batch, num_samples] + + Returns: + tensor: Reconstructed audio given magnitude and phase. 
+ Of shape [num_batch, num_samples] + """ + mag, phase = self.transform(inputs) + rec_wav = self.inverse(mag, phase) + return rec_wav diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py b/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py new file mode 100644 index 0000000..22dad65 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -0,0 +1,359 @@ +import argparse +import json +import os +import re +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +import safetensors.torch +import torch + +from tensorrt_llm import str_dtype_to_torch +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.convert_utils import split, split_matrix_tp + + +def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank): + split_v = split(v, tensor_parallel, rank, dim=1) + return split_v.contiguous() + + +def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): + split_v = split(v, tensor_parallel, rank, dim=0) + return split_v.contiguous() + + +FACEBOOK_DIT_NAME_MAPPING = { + "^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight", + "^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias", + "^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight", + "^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias", + "^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight", + "^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias", + "^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight", + "^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias", + "^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight", + "^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias", + "^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight", + "^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias", + "^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight", + "^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias", + "^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight", + "^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias", + "^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight", + "^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias", + "^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight", + "^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias", + "^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight", + "^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias", + "^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight", + "^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias", + "^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight", + "^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias", + "^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight", + "^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias", + 
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight", + "^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias", + "^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight", + "^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias", + "^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight", + "^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias", + "^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight", + "^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias", + "^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight", + "^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias", + "^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight", + "^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias", + "^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight", + "^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias", + "^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight", + "^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias", + "^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight", + "^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias", + "^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight", + "^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias", + "^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight", + "^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias", + "^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight", + "^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias", + "^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight", + "^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias", + "^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight", + "^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias", + "^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight", + "^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias", + "^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight", + "^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias", + "^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight", + "^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias", + "^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight", + "^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias", + "^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight", + "^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias", + "^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight", + "^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias", + 
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight", + "^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias", + "^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight", + "^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias", + "^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight", + "^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias", + "^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight", + "^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias", + "^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight", + "^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias", + "^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight", + "^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias", + "^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight", + "^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias", + "^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight", + "^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias", + "^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight", + "^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias", + "^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight", + "^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias", + "^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight", + "^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias", + "^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight", + "^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias", + "^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight", + "^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias", + "^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight", + "^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias", + "^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight", + "^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias", + "^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight", + "^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias", + "^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight", + "^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias", + "^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight", + "^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias", + "^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight", + "^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias", + "^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight", + "^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias", + "^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight", + "^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias", + 
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight", + "^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias", + "^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight", + "^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias", + "^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight", + "^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias", + "^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight", + "^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias", + "^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight", + "^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias", + "^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight", + "^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias", + "^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight", + "^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias", + "^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight", + "^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias", + "^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight", + "^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias", + "^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight", + "^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias", + "^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight", + "^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias", + "^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight", + "^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias", + "^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight", + "^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias", + "^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight", + "^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias", + "^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight", + "^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias", +} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="F5TTS_Base", + choices=[ + "F5TTS_Base", + ], + ) # TODO: support F5TTS_v1_Base + parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt") + parser.add_argument( + "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint" + ) + parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT") + parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers") + parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module") + parser.add_argument("--cfg_scale", type=float, default=4.0) + parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size") + parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size") + parser.add_argument("--pp_size", type=int, 
default=1, help="N-way pipeline parallelism size") + parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers") + parser.add_argument( + "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel" + ) + args = parser.parse_args() + return args + + +def convert_timm_dit(args, mapping, dtype="float32"): + weights = {} + tik = time.time() + torch_dtype = str_dtype_to_torch(dtype) + tensor_parallel = mapping.tp_size + + model_params = dict(torch.load(args.timm_ckpt)) + model_params = { + k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer") + } + prefix = "ema_model.transformer." + model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()} + + timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING + + def get_trtllm_name(timm_name): + for k, v in timm_to_trtllm_name.items(): + m = re.match(k, timm_name) + if m is not None: + if "*" in v: + v = v.replace("*", m.groups()[0]) + return v + return timm_name + + weights = dict() + for name, param in model_params.items(): + if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight": + weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1) + else: + weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype) + + assert len(weights) == len(model_params) + + # new_prefix = 'f5_transformer.' + new_prefix = "" + weights = {new_prefix + key: value for key, value in weights.items()} + import math + + scale_factor = math.pow(64, -0.25) + for k, v in weights.items(): + if re.match("^transformer_blocks.*.attn.to_k.weight$", k): + weights[k] *= scale_factor + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + + elif re.match("^transformer_blocks.*.attn.to_k.bias$", k): + weights[k] *= scale_factor + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + + elif re.match("^transformer_blocks.*.attn.to_q.weight$", k): + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] *= scale_factor + + elif re.match("^transformer_blocks.*.attn.to_q.bias$", k): + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + weights[k] *= scale_factor + + elif re.match("^transformer_blocks.*.attn.to_v.weight$", k): + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + + elif re.match("^transformer_blocks.*.attn.to_v.bias$", k): + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) + + elif re.match("^transformer_blocks.*.attn.to_out.weight$", k): + weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1) + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + print(f"Weights loaded. 
Total time: {t}") + return weights + + +def save_config(args): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + config = { + "architecture": "F5TTS", + "dtype": args.dtype, + "hidden_size": 1024, + "num_hidden_layers": 22, + "num_attention_heads": 16, + "dim_head": 64, + "dropout": 0.1, + "ff_mult": 2, + "mel_dim": 100, + "text_num_embeds": 256, + "text_dim": 512, + "conv_layers": 4, + "long_skip_connection": False, + "mapping": { + "world_size": args.cp_size * args.tp_size * args.pp_size, + "cp_size": args.cp_size, + "tp_size": args.tp_size, + "pp_size": args.pp_size, + }, + } + if args.fp8_linear: + config["quantization"] = { + "quant_algo": "FP8", + # TODO: add support for exclude modules. + # 'exclude_modules': "*final_layer*", + } + + with open(os.path.join(args.output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=4) + + +def convert_and_save(args, rank): + if rank == 0: + save_config(args) + + mapping = Mapping( + world_size=args.cp_size * args.tp_size * args.pp_size, + rank=rank, + cp_size=args.cp_size, + tp_size=args.tp_size, + pp_size=args.pp_size, + ) + + weights = convert_timm_dit(args, mapping, dtype=args.dtype) + + safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors")) + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log." + + +def main(): + args = parse_arguments() + world_size = args.cp_size * args.tp_size * args.pp_size + + assert args.pp_size == 1, "PP is not supported yet." + + tik = time.time() + if args.timm_ckpt is None: + return + print("Start converting checkpoints") + execute(args.workers, [convert_and_save] * world_size, args) + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + print(f"Total time of converting checkpoints: {t}") + + +if __name__ == "__main__": + main() diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py new file mode 100644 index 0000000..d94f0d7 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
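+"""
+Export the Vocos vocoder to ONNX (run.sh stage 2). The wrapper below replaces the
+model's ISTFT head with the conv-based STFT from conv_stft.py and exports the graph
+with dynamic batch/length axes on the "mel" input and "waveform" output; the ONNX
+file is then compiled into a TensorRT plan by scripts/export_vocos_trt.sh.
+"""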
+ +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download + +from conv_stft import STFT +from vocos import Vocos +import argparse + +opset_version = 17 + + +def get_args(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--vocoder", + type=str, + default="vocos", + choices=["vocos", "bigvgan"], + help="Vocoder to export", + ) + parser.add_argument( + "--output-path", + type=str, + default="./vocos_vocoder.onnx", + help="Output path", + ) + return parser.parse_args() + + +class ISTFTHead(nn.Module): + def __init__(self, n_fft: int, hop_length: int): + super().__init__() + self.out = None + self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft) + + def forward(self, x: torch.Tensor): + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) + real = mag * torch.cos(p) + imag = mag * torch.sin(p) + audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag") + return audio + + +class VocosVocoder(nn.Module): + def __init__(self, vocos_vocoder): + super(VocosVocoder, self).__init__() + self.vocos_vocoder = vocos_vocoder + istft_head_out = self.vocos_vocoder.head.out + n_fft = self.vocos_vocoder.head.istft.n_fft + hop_length = self.vocos_vocoder.head.istft.hop_length + istft_head_for_export = ISTFTHead(n_fft, hop_length) + istft_head_for_export.out = istft_head_out + self.vocos_vocoder.head = istft_head_for_export + + def forward(self, mel): + waveform = self.vocos_vocoder.decode(mel) + return waveform + + +def export_VocosVocoder(vocos_vocoder, output_path, verbose): + vocos_vocoder = VocosVocoder(vocos_vocoder).cuda() + vocos_vocoder.eval() + + dummy_batch_size = 8 + dummy_input_length = 500 + + dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda() + + with torch.no_grad(): + dummy_waveform = vocos_vocoder(mel=dummy_mel) + print(dummy_waveform.shape) + + dummy_input = dummy_mel + + torch.onnx.export( + vocos_vocoder, + dummy_input, + output_path, + opset_version=opset_version, + do_constant_folding=True, + input_names=["mel"], + output_names=["waveform"], + dynamic_axes={ + "mel": {0: "batch_size", 2: "input_length"}, + "waveform": {0: "batch_size", 1: "output_length"}, + }, + verbose=verbose, + ) + + print("Exported to {}".format(output_path)) + + +def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None): + if vocoder_name == "vocos": + # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) + if is_local: + print(f"Load vocos from local path {local_path}") + config_path = f"{local_path}/config.yaml" + model_path = f"{local_path}/pytorch_model.bin" + else: + print("Download Vocos from huggingface charactr/vocos-mel-24khz") + repo_id = "charactr/vocos-mel-24khz" + config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") + vocoder = Vocos.from_hparams(config_path) + state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + vocoder.load_state_dict(state_dict) + vocoder = vocoder.eval().to(device) + elif vocoder_name == "bigvgan": + raise NotImplementedError("BigVGAN is not supported yet") + vocoder.remove_weight_norm() + vocoder = vocoder.eval().to(device) + return vocoder + + +if __name__ == "__main__": + args = get_args() + vocoder = load_vocoder(vocoder_name=args.vocoder, 
device="cpu", hf_cache_dir=None) + if args.vocoder == "vocos": + export_VocosVocoder(vocoder, args.output_path, verbose=False) diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh new file mode 100644 index 0000000..2702275 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +TRTEXEC="/usr/src/tensorrt/bin/trtexec" + +ONNX_PATH=$1 +ENGINE_PATH=$2 +echo "ONNX_PATH: $ONNX_PATH" +echo "ENGINE_PATH: $ENGINE_PATH" +PRECISION="fp32" + + +MIN_BATCH_SIZE=1 +OPT_BATCH_SIZE=1 +MAX_BATCH_SIZE=8 + +MIN_INPUT_LENGTH=1 +OPT_INPUT_LENGTH=1000 +MAX_INPUT_LENGTH=3000 + +MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}" +MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}" +MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}" + +${TRTEXEC} \ + --minShapes="mel:${MEL_MIN_SHAPE}" \ + --optShapes="mel:${MEL_OPT_SHAPE}" \ + --maxShapes="mel:${MEL_MAX_SHAPE}" \ + --onnx=${ONNX_PATH} \ + --saveEngine=${ENGINE_PATH} + diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py b/src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py new file mode 100644 index 0000000..105cfac --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py @@ -0,0 +1,36 @@ +#! /usr/bin/env python3 +from argparse import ArgumentParser +from string import Template + + +def main(file_path, substitutions, in_place, participant_ids): + with open(file_path) as f: + pbtxt = Template(f.read()) + + sub_dict = {"max_queue_size": 0} + sub_dict["participant_ids"] = participant_ids + for sub in substitutions.split(","): + key, value = sub.split(":") + sub_dict[key] = value + + pbtxt = pbtxt.safe_substitute(sub_dict) + + if in_place: + with open(file_path, "w") as f: + f.write(pbtxt) + else: + print(pbtxt) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("file_path", help="path of the .pbtxt to modify") + parser.add_argument( + "substitutions", + help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...", + ) + parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place") + parser.add_argument("--participant_ids", help="Participant IDs for the model", default="") + args = parser.parse_args() + + main(**vars(args))
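+
+# Example invocation (a sketch mirroring run.sh stage 3; paths follow that script's defaults):
+#   python3 scripts/fill_template.py -i ./model_repo/f5_tts/config.pbtxt \
+#       vocab:./F5-TTS/F5TTS_Base/vocab.txt,model:./F5-TTS/F5TTS_Base/model_1200000.pt,trtllm:./f5_trt_llm_engine,vocoder:vocos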