From 2724f9f101505613a8a5205864b03e69aaaf2288 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Wed, 2 Apr 2025 19:04:45 -0700 Subject: [PATCH 1/7] add Nvidia Triton TensorRT-LLM solution --- .../runtime/triton_trtllm/Dockerfile.server | 9 + src/f5_tts/runtime/triton_trtllm/README.md | 44 ++ .../runtime/triton_trtllm/client_grpc.py | 491 +++++++++++++++++ .../runtime/triton_trtllm/docker-compose.yml | 20 + .../runtime/triton_trtllm/fill_template.py | 42 ++ .../f5_tts/1/f5_tts_trtllm.py | 431 +++++++++++++++ .../model_repo_f5_tts/f5_tts/1/model.py | 271 ++++++++++ .../model_repo_f5_tts/f5_tts/config.pbtxt | 81 +++ .../model_repo_f5_tts/vocoder/1/.gitkeep | 0 .../model_repo_f5_tts/vocoder/config.pbtxt | 32 ++ .../runtime/triton_trtllm/patch/__init__.py | 196 +++++++ .../triton_trtllm/patch/f5tts/model.py | 254 +++++++++ .../triton_trtllm/patch/f5tts/modules.py | 499 ++++++++++++++++++ src/f5_tts/runtime/triton_trtllm/run.sh | 62 +++ .../triton_trtllm/scripts/conv_stft.py | 243 +++++++++ .../scripts/convert_checkpoint.py | 393 ++++++++++++++ .../scripts/export_vocoder_to_onnx.py | 144 +++++ .../triton_trtllm/scripts/export_vocos_trt.sh | 43 ++ 18 files changed, 3255 insertions(+) create mode 100644 src/f5_tts/runtime/triton_trtllm/Dockerfile.server create mode 100644 src/f5_tts/runtime/triton_trtllm/README.md create mode 100644 src/f5_tts/runtime/triton_trtllm/client_grpc.py create mode 100644 src/f5_tts/runtime/triton_trtllm/docker-compose.yml create mode 100644 src/f5_tts/runtime/triton_trtllm/fill_template.py create mode 100644 src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py create mode 100644 src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py create mode 100644 src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt create mode 100644 src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep create mode 100644 src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt create mode 100644 src/f5_tts/runtime/triton_trtllm/patch/__init__.py create mode 100644 src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py create mode 100644 src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py create mode 100644 src/f5_tts/runtime/triton_trtllm/run.sh create mode 100644 src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py create mode 100644 src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py create mode 100644 src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py create mode 100644 src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh diff --git a/src/f5_tts/runtime/triton_trtllm/Dockerfile.server b/src/f5_tts/runtime/triton_trtllm/Dockerfile.server new file mode 100644 index 0000000..b73bfc4 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/Dockerfile.server @@ -0,0 +1,9 @@ +FROM nvcr.io/nvidia/tritonserver:24.12-py3 +RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos +WORKDIR /workspace + + + + + + diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md new file mode 100644 index 0000000..991ef59 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/README.md @@ -0,0 +1,44 @@ +## Triton Inference Serving Best Practice for F5 TTS + +### Quick Start +Directly launch the service using docker compose. +```sh +# TODO: support F5TTS_v1_Base +MODEL=F5TTS_Base docker compose up +``` + +### Build Image +Build the docker image from scratch. +```sh +docker build . 
-f Dockerfile.server -t soar97/triton-f5-tts:24.12 +``` + +### Create Docker Container +```sh +your_mount_dir=/mnt:/mnt +docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12 +``` + +### Export Models to TensorRT-LLM and Launch Server +Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper). + +```sh +bash build_server.sh +``` + +### Benchmark using Dataset +```sh +num_task=2 +python3 client.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts +``` + +### Benchmark Results +Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs. + +| Model | Note | Concurrency | Avg Latency | RTF | +|-------|-----------|-----------------------|---------|--| +| F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394| + +### Credits +1. [F5-TTS](https://github.com/SWivid/F5-TTS) +2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm) \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/client_grpc.py b/src/f5_tts/runtime/triton_trtllm/client_grpc.py new file mode 100644 index 0000000..2f92ab6 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/client_grpc.py @@ -0,0 +1,491 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Nvidia (authors: Yuekai Zhang) +# 2023 Recurrent.ai (authors: Songtao Shi) +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script supports to load dataset from huggingface and sends it to the server +for decoding, in parallel. + +Usage: +num_task=2 + +# For offline F5-TTS +python3 client_grpc.py \ + --server-addr localhost \ + --model-name f5_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name test_zh \ + --log-dir ./log_concurrent_tasks_${num_task} + +# For offline Spark-TTS-0.5B +python3 client_grpc.py \ + --server-addr localhost \ + --model-name spark_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name wenetspeech4tts \ + --log-dir ./log_concurrent_tasks_${num_task} +""" + +import argparse +import asyncio +import json + +import os +import time +import types +from pathlib import Path + +import numpy as np +import soundfile as sf +import tritonclient +import tritonclient.grpc.aio as grpcclient +from tritonclient.utils import np_to_triton_dtype + + + +def write_triton_stats(stats, summary_file): + with open(summary_file, "w") as summary_f: + model_stats = stats["model_stats"] + # write a note, the log is from triton_client.get_inference_statistics(), to better human readability + summary_f.write( + "The log is parsing from triton_client.get_inference_statistics(), to better human readability. 
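The counters reported below are cumulative over all requests served since the server started.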
\n" + ) + summary_f.write("To learn more about the log, please refer to: \n") + summary_f.write( + "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n" + ) + summary_f.write( + "2. https://github.com/triton-inference-server/server/issues/5374 \n\n" + ) + summary_f.write( + "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" + ) + summary_f.write( + "However, there is a trade-off between the increased queue time and the increased batch size. \n" + ) + summary_f.write( + "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" + ) + summary_f.write( + "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n" + ) + for model_state in model_stats: + if "last_inference" not in model_state: + continue + summary_f.write(f"model name is {model_state['name']} \n") + model_inference_stats = model_state["inference_stats"] + total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 + total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 + total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 + total_output_time_s = ( + int(model_inference_stats["compute_output"]["ns"]) / 1e9 + ) + summary_f.write( + f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa + ) + model_batch_stats = model_state["batch_stats"] + for batch in model_batch_stats: + batch_size = int(batch["batch_size"]) + compute_input = batch["compute_input"] + compute_output = batch["compute_output"] + compute_infer = batch["compute_infer"] + batch_count = int(compute_infer["count"]) + assert ( + compute_infer["count"] + == compute_output["count"] + == compute_input["count"] + ) + compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 + compute_input_time_ms = int(compute_input["ns"]) / 1e6 + compute_output_time_ms = int(compute_output["ns"]) / 1e6 + summary_f.write( + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms/batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms/batch_count/batch_size:.2f} ms \n" # noqa + ) + # summary_f.write( + # f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa + # ) + # summary_f.write( + # f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa + # ) + + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=8001, + help="Grpc port of the triton server, default is 8001", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to a single audio file. 
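If provided, the client sends a single request built from --reference-text and --target-text instead of loading a dataset or manifest.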
It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--huggingface-dataset", + type=str, + default="yuekai/seed_tts", + help="dataset name in huggingface dataset hub", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="dataset split name, default is 'test'", + ) + + parser.add_argument( + "--manifest-path", + type=str, + default=None, + help="Path to the manifest dir which includes wav.scp trans.txt files.", + ) + + parser.add_argument( + "--model-name", + type=str, + default="f5_tts", + choices=[ + "f5_tts", "spark_tts" + ], + help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + ) + + parser.add_argument( + "--num-tasks", + type=int, + default=1, + help="Number of concurrent tasks for sending", + ) + + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="Controls how frequently we print the log.", + ) + + parser.add_argument( + "--compute-wer", + action="store_true", + default=False, + help="""True to compute WER. + """, + ) + + parser.add_argument( + "--log-dir", + type=str, + required=False, + default="./tmp", + help="log directory", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Inference batch_size per request for offline mode.", + ) + + return parser.parse_args() + + +def load_audio(wav_path, target_sample_rate=16000): + assert target_sample_rate == 16000, "hard coding in server" + if isinstance(wav_path, dict): + waveform = wav_path["array"] + sample_rate = wav_path["sampling_rate"] + else: + waveform, sample_rate = sf.read(wav_path) + if sample_rate != target_sample_rate: + from scipy.signal import resample + num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) + waveform = resample(waveform, num_samples) + return waveform, target_sample_rate + +async def send( + manifest_item_list: list, + name: str, + triton_client: tritonclient.grpc.aio.InferenceServerClient, + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + padding_duration: int = None, + audio_save_dir: str = "./", +): + total_duration = 0.0 + results = [] + latency_data = [] + task_id = int(name[5:]) + + print(f"manifest_item_list: {manifest_item_list}") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: {i}/{len(manifest_item_list)}") + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + duration = len(waveform) / sample_rate + lengths = np.array([[len(waveform)]], dtype=np.int32) + + reference_text, target_text = item["reference_text"], item["target_text"] + + estimated_target_duration = duration / len(reference_text) * len(target_text) + + if padding_duration: + # padding to nearset 10 seconds + samples = np.zeros( + ( + 1, + padding_duration + * sample_rate + * ((int(duration) // padding_duration) + 1), + ), + dtype=np.float32, + ) + + samples[0, : len(waveform)] = waveform + else: + samples = waveform + + samples = samples.reshape(1, -1).astype(np.float32) + + inputs = [ + protocol_client.InferInput( + "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype) + ), + 
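+            # The remaining inputs mirror the f5_tts config.pbtxt: reference_wav_len as an INT32 tensor of shape [1, 1],
+            # and reference_text / target_text as BYTES (TYPE_STRING) tensors of shape [1, 1].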
protocol_client.InferInput( + "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) + ), + protocol_client.InferInput("reference_text", [1, 1], "BYTES"), + protocol_client.InferInput("target_text", [1, 1], "BYTES") + ] + inputs[0].set_data_from_numpy(samples) + inputs[1].set_data_from_numpy(lengths) + + input_data_numpy = np.array([reference_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[2].set_data_from_numpy(input_data_numpy) + + input_data_numpy = np.array([target_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[3].set_data_from_numpy(input_data_numpy) + + outputs = [protocol_client.InferRequestedOutput("waveform")] + + sequence_id = 100000000 + i + task_id * 10 + start = time.time() + response = await triton_client.infer( + model_name, inputs, request_id=str(sequence_id), outputs=outputs + ) + + audio = response.as_numpy("waveform").reshape(-1) + + end = time.time() - start + + audio_save_path = os.path.join( + audio_save_dir, f"{item['target_audio_path']}.wav" + ) + sf.write(audio_save_path, audio, 16000, "PCM_16") + + latency_data.append((end, estimated_target_duration)) + total_duration += estimated_target_duration + + return total_duration, latency_data + +def load_manifests(manifest_path): + with open(manifest_path, "r") as f: + manifest_list = [] + for line in f: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) + manifest_list.append( + { + "audio_filepath": prompt_wav, + "reference_text": prompt_text, + "target_text": gt_text, + "target_audio_path": utt + } + ) + return manifest_list + + +def split_data(data, k): + n = len(data) + if n < k: + print( + f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}." 
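+            # Note: the data is split as evenly as possible below; the first (n % k) tasks each receive one extra item.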
+ ) + k = n + + quotient = n // k + remainder = n % k + + result = [] + start = 0 + for i in range(k): + if i < remainder: + end = start + quotient + 1 + else: + end = start + quotient + + result.append(data[start:end]) + start = end + + return result + + +async def main(): + args = get_args() + url = f"{args.server_addr}:{args.server_port}" + + triton_client = grpcclient.InferenceServerClient(url=url, verbose=False) + protocol_client = grpcclient + + if args.reference_audio: + args.num_tasks = 1 + args.log_interval = 1 + manifest_item_list = [ + { + "reference_text": args.reference_text, + "target_text": args.target_text, + "audio_filepath": args.reference_audio, + "target_audio_path": "test", + } + ] + elif args.huggingface_dataset: + import datasets + + dataset = datasets.load_dataset( + args.huggingface_dataset, + split=args.split_name, + trust_remote_code=True, + ) + manifest_item_list = [] + for i in range(len(dataset)): + manifest_item_list.append( + { + "audio_filepath": dataset[i]["prompt_audio"], + "reference_text": dataset[i]["prompt_text"], + "target_audio_path": dataset[i]["id"], + "target_text": dataset[i]["target_text"], + } + ) + else: + manifest_item_list = load_manifests(args.manifest_path) + + args.num_tasks = min(args.num_tasks, len(manifest_item_list)) + manifest_item_list = split_data(manifest_item_list, args.num_tasks) + + os.makedirs(args.log_dir, exist_ok=True) + tasks = [] + start_time = time.time() + for i in range(args.num_tasks): + task = asyncio.create_task( + send( + manifest_item_list[i], + name=f"task-{i}", + triton_client=triton_client, + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=1.0, + ) + ) + tasks.append(task) + + ans_list = await asyncio.gather(*tasks) + + end_time = time.time() + elapsed = end_time - start_time + + + total_duration = 0.0 + latency_data = [] + for ans in ans_list: + total_duration += ans[0] + latency_data += ans[1] + + rtf = elapsed / total_duration + + s = f"RTF: {rtf:.4f}\n" + s += f"total_duration: {total_duration:.3f} seconds\n" + s += f"({total_duration/3600:.2f} hours)\n" + s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n" + + latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] + latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 + latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 + s += f"latency_variance: {latency_variance:.2f}\n" + s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" + s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" + s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" + s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" + s += f"average_latency_ms: {latency_ms:.2f}\n" + + print(s) + if args.manifest_path: + name = Path(args.manifest_path).stem + elif args.split_name: + name = args.split_name + with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: + f.write(s) + + stats = await triton_client.get_inference_statistics(model_name="", as_json=True) + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True) + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) +if __name__ == "__main__": + asyncio.run(main()) diff --git 
a/src/f5_tts/runtime/triton_trtllm/docker-compose.yml b/src/f5_tts/runtime/triton_trtllm/docker-compose.yml new file mode 100644 index 0000000..b08bd08 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/docker-compose.yml @@ -0,0 +1,20 @@ +services: + tts: + image: soar97/triton-f5-tts:24.12 + shm_size: '1gb' + ports: + - "8000:8000" + - "8001:8001" + - "8002:8002" + environment: + - PYTHONIOENCODING=utf-8 + - MODEL_ID=${MODEL_ID} + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0'] + capabilities: [gpu] + command: > + /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL" diff --git a/src/f5_tts/runtime/triton_trtllm/fill_template.py b/src/f5_tts/runtime/triton_trtllm/fill_template.py new file mode 100644 index 0000000..584a9f4 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/fill_template.py @@ -0,0 +1,42 @@ +#! /usr/bin/env python3 +from argparse import ArgumentParser +from string import Template + + +def main(file_path, substitutions, in_place, participant_ids): + with open(file_path) as f: + pbtxt = Template(f.read()) + + sub_dict = {"max_queue_size": 0} + sub_dict["participant_ids"] = participant_ids + for sub in substitutions.split(","): + key, value = sub.split(":") + sub_dict[key] = value + + pbtxt = pbtxt.safe_substitute(sub_dict) + + if in_place: + with open(file_path, "w") as f: + f.write(pbtxt) + else: + print(pbtxt) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("file_path", help="path of the .pbtxt to modify") + parser.add_argument( + "substitutions", + help= + "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." 
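+        # Example (illustrative paths): `python3 fill_template.py -i model_repo_f5_tts/f5_tts/config.pbtxt \
+        #     vocab:/workspace/vocab.txt,model:/workspace/model.pt,trtllm:/workspace/engine_dir,vocoder:vocos`
+        # fills the ${vocab}, ${model}, ${trtllm} and ${vocoder} placeholders used in the f5_tts config.pbtxt.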
+ ) + parser.add_argument("--in_place", + "-i", + action="store_true", + help="do the operation in-place") + parser.add_argument("--participant_ids", + help="Participant IDs for the model", + default="") + args = parser.parse_args() + + main(**vars(args)) diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py new file mode 100644 index 0000000..cb7f2d2 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py @@ -0,0 +1,431 @@ + +import tensorrt as trt +import os +import math +import time +from typing import List, Dict, Union, Optional +from functools import wraps + +import tensorrt_llm +from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper +from tensorrt_llm.runtime.session import Session + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def remove_tensor_padding(input_tensor, + input_tensor_lengths=None): + # Audio tensor case: batch, seq_len, feature_len + # position_ids case: batch, seq_len + assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor" + + # Initialize a list to collect valid sequences + valid_sequences = [] + + for i in range(input_tensor.shape[0]): + valid_length = input_tensor_lengths[i] + valid_sequences.append(input_tensor[i, :valid_length]) + + # Concatenate all valid sequences along the batch dimension + output_tensor = torch.cat(valid_sequences, dim=0).contiguous() + return output_tensor + +class TextEmbedding(nn.Module): + def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2, precompute_max_pos = 4096): + super().__init__() + self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False) + self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) + + def forward(self, text): + # only keep tensors with value not -1 + text_mask = text != -1 + text_pad_cut_off_index = text_mask.sum(dim=1).max() + + text = text[:, :text_pad_cut_off_index] + text = self.text_embed(text) + text = text + self.freqs_cis[:text.shape[1], :] + for block in self.text_blocks: + text = block(text) + # padding text to the original length + # text shape: B,seq_len,C + # pad at the second dimension + text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0) + return text + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = 
nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + +def load_checkpoint(ckpt_path, use_ema=True): + checkpoint = torch.load(ckpt_path, weights_only=True) + if use_ema: + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + dict_state = checkpoint["model_state_dict"] + text_embed_dict = {} + for key in dict_state.keys(): + # transformer.text_embed.text_embed.weight -> text_embed.weight + if 'text_embed' in key: + text_embed_dict[key.replace('transformer.text_embed.', '')] = dict_state[key] + return text_embed_dict + +class F5TTS(object): + + def __init__( + self, + config, + debug_mode=True, + stream: Optional[torch.cuda.Stream] = None, + tllm_model_dir: Optional[str] = None, + model_path: Optional[str] = None, + vocab_size: Optional[int] = None, + ): + self.dtype = config['pretrained_config']['dtype'] + + rank = tensorrt_llm.mpi_rank() + world_size = config['pretrained_config']['mapping']['world_size'] + cp_size = config['pretrained_config']['mapping']['cp_size'] + tp_size = config['pretrained_config']['mapping']['tp_size'] + pp_size = config['pretrained_config']['mapping']['pp_size'] + assert pp_size == 1 + self.mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + cp_size=cp_size, + tp_size=tp_size, + pp_size=1, + gpus_per_node=1 + ) + + local_rank = rank % self.mapping.gpus_per_node + self.device = torch.device(f"cuda:{local_rank}") + + torch.cuda.set_device(self.device) + + self.stream = stream + if self.stream is None: + self.stream = torch.cuda.Stream(self.device) + torch.cuda.set_stream(self.stream) + + engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine") + logger.info(f'Loading engine from {engine_file}') + with open(engine_file, "rb") as f: + engine_buffer = f.read() + + assert engine_buffer is not None + + self.session = Session.from_serialized_engine(engine_buffer) + + self.debug_mode = debug_mode + + self.inputs = {} + self.outputs = {} + self.buffer_allocated = False + + expected_tensor_names = ['noise', 'cond', 'time', 'rope_cos', 'rope_sin', 'input_lengths', 'denoised'] + + found_tensor_names = [ + self.session.engine.get_tensor_name(i) + for i in range(self.session.engine.num_io_tensors) + ] + if not self.debug_mode and set(expected_tensor_names) != set( + found_tensor_names): + logger.error( + f"The following 
expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}" + ) + logger.error( + f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}" + ) + logger.error(f"Expected tensor names: {expected_tensor_names}") + logger.error(f"Found tensor names: {found_tensor_names}") + raise RuntimeError( + "Tensor names in engine are not the same as expected.") + if self.debug_mode: + self.debug_tensors = list( + set(found_tensor_names) - set(expected_tensor_names)) + + self.max_mel_len = 4096 + self.text_embedding = TextEmbedding( + text_num_embeds=vocab_size, + text_dim=512, + conv_layers=4, + precompute_max_pos=self.max_mel_len + ).to(self.device) + self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True) + + self.target_audio_sample_rate = 24000 + self.target_rms = 0.15 # target rms for audio + self.n_fft = 1024 + self.win_length = 1024 + self.hop_length = 256 + self.n_mel_channels = 100 + #self.max_mel_len = 3000 + self.head_dim = 64 + self.base_rescale_factor = 1.0 + self.interpolation_factor = 1.0 + base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)) + freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor + self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0) + self.rope_cos = self.freqs.cos().half() + self.rope_sin = self.freqs.sin().half() + self.nfe_steps = 16 + t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32) + time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t) + delta_t = torch.diff(time_step) + # WAR: hard coding 256 here + tmp_dim = 256 + time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32) + half_dim = tmp_dim // 2 + emb_factor = math.log(10000) / (half_dim - 1) + emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor) + for i in range(self.nfe_steps): + emb = time_step[i] * emb_factor + time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1) + self.time_expand = time_expand.to(self.device) + self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device) + + def _tensor_dtype(self, name): + # return torch dtype given tensor name for convenience + dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name)) + return dtype + + def _setup(self, batch_size, seq_len): + for i in range(self.session.engine.num_io_tensors): + name = self.session.engine.get_tensor_name(i) + if self.session.engine.get_tensor_mode( + name) == trt.TensorIOMode.OUTPUT: + shape = list(self.session.engine.get_tensor_shape(name)) + shape[0] = batch_size + shape[1] = seq_len + self.outputs[name] = torch.empty(shape, + dtype=self._tensor_dtype(name), + device=self.device) + + self.buffer_allocated = True + + def cuda_stream_guard(func): + """Sync external stream and set current stream to the one bound to the session. Reset on exit. 
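+        Keeping every TensorRT enqueue on the session's stream avoids races with PyTorch work issued on the caller's stream.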
+ """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + external_stream = torch.cuda.current_stream() + if external_stream != self.stream: + external_stream.synchronize() + torch.cuda.set_stream(self.stream) + ret = func(self, *args, **kwargs) + if external_stream != self.stream: + self.stream.synchronize() + torch.cuda.set_stream(external_stream) + return ret + + return wrapper + + @cuda_stream_guard + def forward(self, noise: torch.Tensor, cond: torch.Tensor, + time_expand: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, + input_lengths: torch.Tensor, delta_t: torch.Tensor, use_perf: bool = False): + if use_perf: + torch.cuda.nvtx.range_push("flow matching") + cfg_strength = 2.0 + batch_size = noise.shape[0] + half_batch = batch_size // 2 + noise_half = noise[:half_batch] # Store the initial half of noise + + input_type = str_dtype_to_torch(self.dtype) + + # Keep a copy of the initial tensors + cond = cond.to(input_type) + rope_cos = rope_cos.to(input_type) + rope_sin = rope_sin.to(input_type) + input_lengths = input_lengths.to(str_dtype_to_torch("int32")) + + # Instead of iteratively updating noise within a single model context, + # we'll do a single forward pass for each iteration with fresh context setup + for i in range(self.nfe_steps): + # Re-setup the buffers for clean execution + self._setup(batch_size, noise.shape[1]) + if not self.buffer_allocated: + raise RuntimeError('Buffer not allocated, please call setup first!') + + # Re-create combined noises for this iteration + current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type) + + # Get time step for this iteration + current_time = time_expand[:, i].to(input_type) + + # Create fresh input dictionary for this iteration + current_inputs = { + 'noise': current_noise, + 'cond': cond, + 'time': current_time, + 'rope_cos': rope_cos, + 'rope_sin': rope_sin, + 'input_lengths': input_lengths, + } + + # Update inputs and set shapes + self.inputs.clear() # Clear previous inputs + self.inputs.update(**current_inputs) + self.session.set_shapes(self.inputs) + + if use_perf: + torch.cuda.nvtx.range_push(f"execute {i}") + ok = self.session.run(self.inputs, self.outputs, + self.stream.cuda_stream) + assert ok, "Failed to execute model" + # self.session.context.execute_async_v3(self.stream.cuda_stream) + if use_perf: + torch.cuda.nvtx.range_pop() + # Process results + t_scale = delta_t[i].unsqueeze(0).to(input_type) + + # Extract predictions + pred_cond = self.outputs["denoised"][:half_batch] + pred_uncond = self.outputs["denoised"][half_batch:] + + # Apply classifier-free guidance with safeguards + guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength + # Calculate update for noise + noise_half = noise_half + guidance * t_scale + if use_perf: + torch.cuda.nvtx.range_pop() + return noise_half + + def sample(self, text_pad_sequence: torch.Tensor, ref_mel_batch: torch.Tensor, + ref_mel_len_batch: torch.Tensor, estimated_reference_target_mel_len: List[int], remove_input_padding: bool = False, use_perf: bool = False): + if use_perf: + torch.cuda.nvtx.range_push("text embedding") + batch = text_pad_sequence.shape[0] + max_seq_len = ref_mel_batch.shape[1] + + text_pad_sequence_drop = torch.cat( + (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), + dim=0 + ) + + text_embedding_drop_list = [] + for i in range(batch+1): + text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device))) + 
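+        # The all-zeros row appended to text_pad_sequence_drop above yields the text-dropped (unconditional) embedding
+        # for classifier-free guidance; it is split back off below and repeated across the batch, so the conditional and
+        # unconditional branches run through the engine together in one doubled batch.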
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0) + + text_embedding = text_embedding_drop_condition[:-1] + # text_embedding_drop B,T,C batch should be the same + text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1) + + noise = torch.randn_like(ref_mel_batch).to(self.device) + rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) + rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1) + + cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1) + cat_mel_text_drop = torch.cat( + (torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), text_embedding_drop), + dim=-1 + ) + + time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous() + + # Convert estimated_reference_target_mel_len to tensor + input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32) + + # combine above along the batch dimension + inputs = { + 'noise': torch.cat((noise, noise), dim=0).contiguous(), + 'cond': torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(), + 'time_expand': time_expand, + 'rope_cos': torch.cat((rope_cos, rope_cos), dim=0).contiguous(), + 'rope_sin': torch.cat((rope_sin, rope_sin), dim=0).contiguous(), + 'input_lengths': torch.cat((input_lengths, input_lengths), dim=0).contiguous(), + 'delta_t': self.delta_t + } + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_push("remove input padding") + if remove_input_padding: + max_seq_len = inputs['cond'].shape[1] + inputs['noise'] = remove_tensor_padding(inputs['noise'], inputs['input_lengths']) + inputs['cond'] = remove_tensor_padding(inputs['cond'], inputs['input_lengths']) + # for time_expand, convert from B,D to B,T,D by repeat + inputs['time_expand'] = inputs['time_expand'].unsqueeze(1).repeat(1, max_seq_len, 1, 1) + inputs['time_expand'] = remove_tensor_padding(inputs['time_expand'], inputs['input_lengths']) + inputs['rope_cos'] = remove_tensor_padding(inputs['rope_cos'], inputs['input_lengths']) + inputs['rope_sin'] = remove_tensor_padding(inputs['rope_sin'], inputs['input_lengths']) + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_pop() + for key in inputs: + inputs[key] = inputs[key].to(self.device) + if use_perf: + torch.cuda.nvtx.range_pop() + start_time = time.time() + denoised = self.forward(**inputs, use_perf=use_perf) + cost_time = time.time() - start_time + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_push("remove input padding output") + if remove_input_padding: + denoised_list = [] + start_idx = 0 + for i in range(batch): + denoised_list.append(denoised[start_idx:start_idx+inputs['input_lengths'][i]]) + start_idx += inputs['input_lengths'][i] + if use_perf and remove_input_padding: + torch.cuda.nvtx.range_pop() + return denoised_list, cost_time + return denoised, cost_time \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py new file mode 100644 index 0000000..8337185 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py @@ -0,0 +1,271 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence +import torch.nn.functional as F +from torch.utils.dlpack import from_dlpack, to_dlpack +import torchaudio +import jieba +import triton_python_backend_utils as pb_utils +from pypinyin import Style, lazy_pinyin +import math +import os +from f5_tts_trtllm import F5TTS +torch.manual_seed(0) + +def get_tokenizer(vocab_file_path: str): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + with open(vocab_file_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + return vocab_char_map, vocab_size + +def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): + final_reference_target_texts_list = [] + custom_trans = str.maketrans( + {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"} + ) # add custom trans here, to address oov + + def is_chinese(c): + return "\u3100" <= c <= "\u9fff" # common chinese characters + + for text in reference_target_texts_list: + char_list = [] + text = text.translate(custom_trans) + for seg in jieba.cut(text): + seg_byte_len = len(bytes(seg, "UTF-8")) + if seg_byte_len == len(seg): # if pure alphabets and symbols + if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": + char_list.append(" ") + char_list.extend(seg) + elif polyphone and seg_byte_len == 3 * len( + seg + ): # if pure east asian characters + seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) + for i, c in enumerate(seg): + if is_chinese(c): + char_list.append(" ") + 
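+                        # seg_ holds the TONE3 pinyin for the whole segment, so index it by character position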
char_list.append(seg_[i]) + else: # if mixed characters, alphabets and symbols + for c in seg: + if ord(c) < 256: + char_list.extend(c) + elif is_chinese(c): + char_list.append(" ") + char_list.extend( + lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) + ) + else: + char_list.append(c) + final_reference_target_texts_list.append(char_list) + + return final_reference_target_texts_list + +def list_str_to_idx( + text: list[str] | list[list[str]], + vocab_char_map: dict[str, int], # {char: idx} + padding_value=-1, +): # noqa: F722 + list_idx_tensors = [ + torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text + ] # pinyin or char style + # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + return list_idx_tensors + +class TritonPythonModel: + def initialize(self, args): + self.use_perf = True + self.device = torch.device("cuda") + self.target_audio_sample_rate = 24000 + self.target_rms = 0.15 # target rms for audio + self.n_fft = 1024 + self.win_length = 1024 + self.hop_length = 256 + self.n_mel_channels = 100 + self.max_mel_len = 3000 + self.head_dim = 64 + + + parameters = json.loads(args['model_config'])['parameters'] + for key, value in parameters.items(): + parameters[key] = value["string_value"] + + self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"]) + self.reference_sample_rate = int(parameters["reference_audio_sample_rate"]) + self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate) + + self.tllm_model_dir = parameters["tllm_model_dir"] + config_file = os.path.join(self.tllm_model_dir, 'config.json') + with open(config_file) as f: + config = json.load(f) + self.model = F5TTS(config, debug_mode=False, tllm_model_dir=self.tllm_model_dir, model_path=parameters["model_path"], vocab_size=self.vocab_size) + + self.vocoder = parameters["vocoder"] + assert self.vocoder in ["vocos", "bigvgan"] + if self.vocoder == "vocos": + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_audio_sample_rate, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + n_mels=self.n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(self.device) + self.compute_mel_fn = self.get_vocos_mel_spectrogram + elif self.vocoder == "bigvgan": + self.compute_mel_fn = self.get_bigvgan_mel_spectrogram + + def get_vocos_mel_spectrogram(self, waveform): + mel = self.mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel.transpose(1, 2) + + def forward_vocoder(self, mel): + mel = mel.to(torch.float32).contiguous() + input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)) + + inference_request = pb_utils.InferenceRequest( + model_name='vocoder', + requested_output_names=['waveform'], + inputs=[input_tensor_0]) + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + else: + waveform = pb_utils.get_output_tensor_by_name(inference_response, + 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def execute(self, requests): + reference_text_list, target_text_list, reference_target_texts_list, reference_wavs_tensor, estimated_reference_target_mel_len, reference_mel_len = [], [], [], [], [], [] + max_wav_len = 0 + mel_features_list = [] + if self.use_perf: + torch.cuda.nvtx.range_push("preprocess") + for request in requests: + wav_tensor = 
pb_utils.get_input_tensor_by_name(request, "reference_wav") + wav_lens = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len") + + reference_text = pb_utils.get_input_tensor_by_name( + request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + reference_text_list.append(reference_text) + target_text = pb_utils.get_input_tensor_by_name( + request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + target_text_list.append(target_text) + + text = reference_text + target_text + reference_target_texts_list.append(text) + + wav = from_dlpack(wav_tensor.to_dlpack()) + wav_len = from_dlpack(wav_lens.to_dlpack()) + wav_len = wav_len.squeeze() + assert wav.shape[0] == 1, "Only support batch size 1 for now." + wav = wav[:, :wav_len] + + ref_rms = torch.sqrt(torch.mean(torch.square(wav))) + if ref_rms < self.target_rms: + wav = wav * self.target_rms / ref_rms + if self.reference_sample_rate != self.target_audio_sample_rate: + wav = self.resampler(wav) + wav = wav.to(self.device) + if self.use_perf: + torch.cuda.nvtx.range_push("compute_mel") + mel_features = self.compute_mel_fn(wav) + if self.use_perf: + torch.cuda.nvtx.range_pop() + mel_features_list.append(mel_features) + + reference_mel_len.append(mel_features.shape[1]) + estimated_reference_target_mel_len.append(int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))) + + max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len) + + batch = len(requests) + print(f"The current batch is {batch}") + mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device) + for i, mel in enumerate(mel_features_list): + mel_features[i, :mel.shape[1], :] = mel + + reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device) + + pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) + text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map) + + for i, item in enumerate(text_pad_sequence): + text_pad_sequence[i] = F.pad(item, (0, estimated_reference_target_mel_len[i] - len(item)), mode='constant', value=-1) + text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS + text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device) + text_pad_sequence = F.pad(text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode='constant', value=-1) + if self.use_perf: + torch.cuda.nvtx.range_pop() + + denoised, cost_time = self.model.sample( + text_pad_sequence, + mel_features, + reference_mel_len_tensor, + estimated_reference_target_mel_len, + remove_input_padding=False, + use_perf=self.use_perf, + ) + if self.use_perf: + torch.cuda.nvtx.range_push("vocoder") + + responses = [] + for i in range(batch): + ref_me_len = reference_mel_len[i] + estimated_mel_len = estimated_reference_target_mel_len[i] + denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2) + audio = self.forward_vocoder(denoised_one_item) + rms = torch.sqrt(torch.mean(torch.square(audio))) + if rms < self.target_rms: + audio = audio * self.target_rms / rms + + audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio]) + responses.append(inference_response) + if self.use_perf: + torch.cuda.nvtx.range_pop() + return responses diff --git 
a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt new file mode 100644 index 0000000..171211e --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "f5_tts" +backend: "python" +max_batch_size: 1 +dynamic_batching { + max_queue_delay_microseconds: 1 +} +parameters [ + { + key: "vocab_file" + value: { string_value: "${vocab}"} + }, + { + key: "model_path", + value: {string_value:"${model}"} + }, + { + key: "tllm_model_dir", + value: {string_value:"${trtllm}"} + }, + { + key: "reference_audio_sample_rate", + value: {string_value:"16000"} + }, + { + key: "vocoder", + value: {string_value:"${vocoder}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + optional: True + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + optional: True + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/1/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt new file mode 100644 index 0000000..9a30b52 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/vocoder/config.pbtxt @@ -0,0 +1,32 @@ +name: "vocoder" +backend: "tensorrt" +default_model_filename: "vocoder.plan" +max_batch_size: 4 + +input [ + { + name: "mel" + data_type: TYPE_FP32 + dims: [ 100, -1 ] + } +] + +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +dynamic_batching { + preferred_batch_size: [1, 2, 4] + max_queue_delay_microseconds: 1 +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/patch/__init__.py b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py new file mode 100644 index 0000000..d43cacc --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .baichuan.model import BaichuanForCausalLM +from .bert.model import (BertForQuestionAnswering, + BertForSequenceClassification, BertModel, + RobertaForQuestionAnswering, + RobertaForSequenceClassification, RobertaModel) +from .bloom.model import BloomForCausalLM, BloomModel +from .chatglm.config import ChatGLMConfig +from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel +from .cogvlm.config import CogVLMConfig +from .cogvlm.model import CogVLMForCausalLM +from .commandr.model import CohereForCausalLM +from .dbrx.config import DbrxConfig +from .dbrx.model import DbrxForCausalLM +from .deepseek_v1.model import DeepseekForCausalLM +from .deepseek_v2.model import DeepseekV2ForCausalLM +from .dit.model import DiT +from .eagle.model import EagleForCausalLM +from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder +from .falcon.config import FalconConfig +from .falcon.model import FalconForCausalLM, FalconModel +from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig +from .gemma.model import GemmaForCausalLM +from .gpt.config import GPTConfig +from .gpt.model import GPTForCausalLM, GPTModel +from .gptj.config import GPTJConfig +from .gptj.model import GPTJForCausalLM, GPTJModel +from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel +from .grok.model import GrokForCausalLM +from .llama.config import LLaMAConfig +from .llama.model import LLaMAForCausalLM, LLaMAModel +from .mamba.model import MambaForCausalLM +from .medusa.config import MedusaConfig +from .medusa.model import MedusaForCausalLm +from .mllama.model import MLLaMAModel +from .modeling_utils import (PretrainedConfig, PretrainedModel, + SpeculativeDecodingMode) +from .mpt.model import MPTForCausalLM, MPTModel +from .nemotron_nas.model import DeciLMForCausalLM +from .opt.model import OPTForCausalLM, OPTModel +from .phi3.model import Phi3ForCausalLM, Phi3Model +from .phi.model import PhiForCausalLM, PhiModel +from .qwen.model import QWenForCausalLM +from .recurrentgemma.model import RecurrentGemmaForCausalLM +from .redrafter.model import ReDrafterForCausalLM +from .f5tts.model import F5TTS + +__all__ = [ + 'BertModel', + 'BertForQuestionAnswering', + 'BertForSequenceClassification', + 'RobertaModel', + 'RobertaForQuestionAnswering', + 'RobertaForSequenceClassification', + 'BloomModel', + 'BloomForCausalLM', + 'DiT', + 'DeepseekForCausalLM', + 'FalconConfig', + 'DeepseekV2ForCausalLM', + 'FalconForCausalLM', + 'FalconModel', + 'GPTConfig', + 'GPTModel', + 'GPTForCausalLM', + 'OPTForCausalLM', + 'OPTModel', + 'LLaMAConfig', + 'LLaMAForCausalLM', + 'LLaMAModel', + 'MedusaConfig', + 'MedusaForCausalLm', + 'ReDrafterForCausalLM', + 'GPTJConfig', + 'GPTJModel', + 'GPTJForCausalLM', + 'GPTNeoXModel', + 'GPTNeoXForCausalLM', + 'PhiModel', + 'PhiConfig', + 'Phi3Model', + 'Phi3Config', + 'PhiForCausalLM', + 'Phi3ForCausalLM', + 'ChatGLMConfig', + 'ChatGLMForCausalLM', + 'ChatGLMModel', + 'BaichuanForCausalLM', + 'QWenConfig' + 'QWenForCausalLM', + 'QWenModel', + 'EncoderModel', + 'DecoderModel', + 'PretrainedConfig', + 'PretrainedModel', + 'WhisperEncoder', + 'MambaForCausalLM', + 
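+    # This patch file mirrors the stock tensorrt_llm/models/__init__.py with F5TTS added (imported above, listed at the
+    # end of __all__, and registered in MODEL_MAP below) so that trtllm-build can resolve the F5TTS architecture.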
'MambaConfig', + 'MPTForCausalLM', + 'MPTModel', + 'SkyworkForCausalLM', + 'GemmaConfig', + 'GemmaForCausalLM', + 'DbrxConfig', + 'DbrxForCausalLM', + 'RecurrentGemmaForCausalLM', + 'CogVLMConfig', + 'CogVLMForCausalLM', + 'EagleForCausalLM', + 'SpeculativeDecodingMode', + 'CohereForCausalLM', + 'MLLaMAModel', + 'F5TTS', +] + +MODEL_MAP = { + 'GPT2LMHeadModel': GPTForCausalLM, + 'GPT2LMHeadCustomModel': GPTForCausalLM, + 'GPTBigCodeForCausalLM': GPTForCausalLM, + 'Starcoder2ForCausalLM': GPTForCausalLM, + 'FuyuForCausalLM': GPTForCausalLM, + 'Kosmos2ForConditionalGeneration': GPTForCausalLM, + 'JAISLMHeadModel': GPTForCausalLM, + 'GPTForCausalLM': GPTForCausalLM, + 'NemotronForCausalLM': GPTForCausalLM, + 'OPTForCausalLM': OPTForCausalLM, + 'BloomForCausalLM': BloomForCausalLM, + 'RWForCausalLM': FalconForCausalLM, + 'FalconForCausalLM': FalconForCausalLM, + 'PhiForCausalLM': PhiForCausalLM, + 'Phi3ForCausalLM': Phi3ForCausalLM, + 'Phi3VForCausalLM': Phi3ForCausalLM, + 'Phi3SmallForCausalLM': Phi3ForCausalLM, + 'PhiMoEForCausalLM': Phi3ForCausalLM, + 'MambaForCausalLM': MambaForCausalLM, + 'GPTNeoXForCausalLM': GPTNeoXForCausalLM, + 'GPTJForCausalLM': GPTJForCausalLM, + 'MPTForCausalLM': MPTForCausalLM, + 'GLMModel': ChatGLMForCausalLM, + 'ChatGLMModel': ChatGLMForCausalLM, + 'ChatGLMForCausalLM': ChatGLMForCausalLM, + 'LlamaForCausalLM': LLaMAForCausalLM, + 'ExaoneForCausalLM': LLaMAForCausalLM, + 'MistralForCausalLM': LLaMAForCausalLM, + 'MixtralForCausalLM': LLaMAForCausalLM, + 'ArcticForCausalLM': LLaMAForCausalLM, + 'Grok1ModelForCausalLM': GrokForCausalLM, + 'InternLMForCausalLM': LLaMAForCausalLM, + 'InternLM2ForCausalLM': LLaMAForCausalLM, + 'MedusaForCausalLM': MedusaForCausalLm, + 'ReDrafterForCausalLM': ReDrafterForCausalLM, + 'BaichuanForCausalLM': BaichuanForCausalLM, + 'BaiChuanForCausalLM': BaichuanForCausalLM, + 'SkyworkForCausalLM': LLaMAForCausalLM, + GEMMA_ARCHITECTURE: GemmaForCausalLM, + GEMMA2_ARCHITECTURE: GemmaForCausalLM, + 'QWenLMHeadModel': QWenForCausalLM, + 'QWenForCausalLM': QWenForCausalLM, + 'Qwen2ForCausalLM': QWenForCausalLM, + 'Qwen2MoeForCausalLM': QWenForCausalLM, + 'Qwen2ForSequenceClassification': QWenForCausalLM, + 'Qwen2VLForConditionalGeneration': QWenForCausalLM, + 'WhisperEncoder': WhisperEncoder, + 'EncoderModel': EncoderModel, + 'DecoderModel': DecoderModel, + 'DbrxForCausalLM': DbrxForCausalLM, + 'RecurrentGemmaForCausalLM': RecurrentGemmaForCausalLM, + 'CogVLMForCausalLM': CogVLMForCausalLM, + 'DiT': DiT, + 'DeepseekForCausalLM': DeepseekForCausalLM, + 'DeciLMForCausalLM': DeciLMForCausalLM, + 'DeepseekV2ForCausalLM': DeepseekV2ForCausalLM, + 'EagleForCausalLM': EagleForCausalLM, + 'CohereForCausalLM': CohereForCausalLM, + 'MllamaForConditionalGeneration': MLLaMAModel, + 'BertForQuestionAnswering': BertForQuestionAnswering, + 'BertForSequenceClassification': BertForSequenceClassification, + 'BertModel': BertModel, + 'RobertaModel': RobertaModel, + 'RobertaForQuestionAnswering': RobertaForQuestionAnswering, + 'RobertaForSequenceClassification': RobertaForSequenceClassification, + 'F5TTS': F5TTS +} diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py new file mode 100644 index 0000000..45bfec5 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py @@ -0,0 +1,254 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations +import sys +import os +current_file_path = 
os.path.abspath(__file__) +parent_dir = os.path.dirname(current_file_path) +sys.path.append(parent_dir) +import math +import numpy as np +import torch +from torch import nn +import tensorrt as trt +from collections import OrderedDict +from ..._utils import str_dtype_to_trt, trt_dtype_to_str, trt_dtype_to_np +from ...plugin import current_all_reduce_helper +from ..modeling_utils import PretrainedConfig, PretrainedModel +from ...functional import (Tensor, allgather, arange, chunk, concat, constant, + cos, exp, expand, shape, silu, sin, slice, split, + unsqueeze, squeeze, cast) +from ...module import Module, ModuleList +from tensorrt_llm._common import default_net +from ...layers import Linear + +from .modules import ( + TimestepEmbedding, + # ConvNeXtV2Block, + ConvPositionEmbedding, + DiTBlock, + AdaLayerNormZero_Final, + # precompute_freqs_cis, get_pos_embed_indices, +) + +# Text embedding +# class TextEmbedding(Module): +# def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): +# super().__init__() +# self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token + +# if conv_layers > 0: +# self.extra_modeling = True +# self.precompute_max_pos = 4096 # ~44s of 24khz audio +# self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) +# self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) +# else: +# self.extra_modeling = False + +# def forward(self, text: int['b nt'], seq_len): +# text = self.text_embed(text) # b n -> b n d + +# # possible extra modeling +# if self.extra_modeling: +# # sinus pos emb +# pos_idx = get_pos_embed_indices(torch.zeros(1, dtype=torch.int32), seq_len, max_pos=self.precompute_max_pos) +# # convnextv2 blocks +# text = self.text_blocks(text + self.freqs_cis[pos_idx]) + +# return text + +class InputEmbedding(Module): + def __init__(self, mel_dim, text_dim, out_dim): + super().__init__() + self.proj = Linear(mel_dim * 2 + text_dim, out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim) + + def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False): + # if drop_audio_cond: # cfg for cond audio + x = self.proj(concat([x, cond], dim = -1)) + return self.conv_pos_embed(x) + x + +# Transformer backbone using DiT blocks +# class F5TTS(PretrainedModel): +# def __init__(self, config: PretrainedConfig): +# super().__init__(config) +# self.f5_transformer = DiT_transformer(config) +# self.dtype = str_dtype_to_trt(config.dtype) +# self.cfg_strength = 2 + +# def forward(self, +# noise: float['b n d'], # nosied input audio +# cond: float['b n d'], # masked cond audio +# cond_drop: float['b n d'], +# time: float['b n'], # time step +# rope_cos: float['b n d'], +# rope_sin: float['b n d'], +# t_scale: float['b'], +# mask: bool['b n'] | None = None): + +# pred = self.f5_transformer(x = noise, cond = cond, cond_drop = cond_drop, time = time, rope_cos = rope_cos, rope_sin = rope_sin, mask = mask) +# pred, pred1 = chunk(pred, 2, dim = 0), chunk works only for static tensor +# # cfg_strength = constant(np.array([self.cfg_strength], dtype = np.float32)).cast(noise.dtype) +# # noise = noise + (pred_cond + (pred_cond - pred_uncond) * cfg_strength) * t_scale +# noise.mark_output('denoised', self.dtype) +# return noise + + + +class F5TTS(PretrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.dtype = str_dtype_to_trt(config.dtype) + + self.time_embed = 
TimestepEmbedding(config.hidden_size) # √ + if config.text_dim is None: + text_dim = config.mel_dim + self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size) + + self.dim = config.hidden_size + self.depth = config.num_hidden_layers + self.transformer_blocks = ModuleList( + [ + DiTBlock( + dim = self.dim, + heads = config.num_attention_heads, + dim_head = config.dim_head, + ff_mult = config.ff_mult, + dropout = config.dropout + ) + for _ in range(self.depth) + ] + ) + + self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = Linear(config.hidden_size, config.mel_dim) + + def forward( + self, + noise: float['b n d'], # nosied input audio + cond: float['b n d'], # masked cond audio + time: float['b n'], # time step + rope_cos: float['b n d'] , + rope_sin: float['b n d'], + input_lengths: int['b'], + scale = 1.0 + ): + t = self.time_embed(time) + x = self.input_embed(noise, cond) + # x = concat([self.input_embed(x, cond), self.input_embed(x, cond_drop)], dim = 0) + + for block in self.transformer_blocks: + x = block(x, t, rope_cos = rope_cos, rope_sin = rope_sin, input_lengths=input_lengths, scale = scale) + denoise = self.proj_out(self.norm_out(x, t)) + denoise.mark_output('denoised', self.dtype) + return denoise + + def prepare_inputs(self, **kwargs): + max_batch_size = kwargs['max_batch_size'] + batch_size_range = [2, 2, max_batch_size] + mel_size = 100 + max_seq_len = 3000 + num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size] + hidden_size = 512 + concat_feature_dim = mel_size + hidden_size + freq_embed_dim=256 + head_dim = 64 + mapping = self.config.mapping + if mapping.tp_size > 1: + current_all_reduce_helper().set_workspace_tensor(mapping, 1) + if default_net().plugin_config.remove_input_padding: + noise = Tensor( + name='noise', + dtype=self.dtype, + shape=[-1, mel_size], + dim_range=OrderedDict([ + ('num_frames', [num_frames_range]), + ('n_mels', [mel_size]), + ])) + cond = Tensor( + name='cond', + dtype=self.dtype, + shape=[-1, concat_feature_dim], + dim_range=OrderedDict([ + ('num_frames', [num_frames_range]), + ('embeded_length', [concat_feature_dim]), + ])) + time = Tensor(name='time', + dtype=self.dtype, + shape=[-1, freq_embed_dim], + dim_range=OrderedDict([ + ('num_frames', [num_frames_range]), + ('freq_dim', [freq_embed_dim]), + ])) + rope_cos = Tensor(name='rope_cos', + dtype=self.dtype, + shape=[-1, head_dim], + dim_range=OrderedDict([ + ('num_frames', [num_frames_range]), + ('head_dim', [head_dim]), + ])) + rope_sin = Tensor(name='rope_sin', + dtype=self.dtype, + shape=[-1, head_dim], + dim_range=OrderedDict([ + ('num_frames', [num_frames_range]), + ('head_dim', [head_dim]), + ])) + + else: + noise = Tensor( + name='noise', + dtype=self.dtype, + shape=[-1, -1, mel_size], + dim_range=OrderedDict([ + ('batch_size', [batch_size_range]), + ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), + ('n_mels', [mel_size]), + ])) + cond = Tensor( + name='cond', + dtype=self.dtype, + shape=[-1, -1, concat_feature_dim], + dim_range=OrderedDict([ + ('batch_size', [batch_size_range]), + ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), + ('embeded_length', [concat_feature_dim]), + ])) + print(233333333333333333333333333333333333333333333333333, batch_size_range) + time = Tensor(name='time', + dtype=self.dtype, + shape=[-1, freq_embed_dim], + dim_range=OrderedDict([ + ('batch_size', [batch_size_range]), + ('freq_dim', [freq_embed_dim]), + ])) + rope_cos = Tensor(name='rope_cos', + 
dtype=self.dtype, + shape=[-1, -1, head_dim], + dim_range=OrderedDict([ + ('batch_size', [batch_size_range]), + ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), + ('head_dim', [head_dim]), + ])) + rope_sin = Tensor(name='rope_sin', + dtype=self.dtype, + shape=[-1, -1, head_dim], + dim_range=OrderedDict([ + ('batch_size', [batch_size_range]), + ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), + ('head_dim', [head_dim]), + ])) + input_lengths = Tensor( + name="input_lengths", + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([("batch_size", [batch_size_range])]), + ) + return {'noise': noise, 'cond': cond, 'time': time, 'rope_cos': rope_cos, 'rope_sin': rope_sin, 'input_lengths': input_lengths} \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py new file mode 100644 index 0000000..896e3d7 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +# import torchaudio +# from librosa.filters import mel as librosa_mel_fn +from torch import nn +import numpy as np +import tensorrt as trt +from tensorrt_llm._common import default_net +from ..._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs, + trt_dtype_to_np, trt_dtype_to_str,str_dtype_to_trt) +from ...functional import (Tensor, allgather, arange, chunk, concat, constant, + cos, exp, expand, shape, silu, sin, slice, split, permute, expand_mask, expand_dims_like, + unsqueeze, matmul, softmax, where, RopeEmbeddingUtils, minimum, repeat_interleave, squeeze, cast, gelu) +from ...functional import expand_dims, view, bert_attention +from ...layers import MLP, BertAttention, Conv2d, LayerNorm, Linear, Conv1d, Mish, embedding, RowLinear, ColumnLinear +from ...module import Module, ModuleList + +# class GRN(Module): +# def __init__(self, dim): +# super().__init__() +# self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) +# self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + +# def forward(self, x): +# Gx = torch.norm(x, p=2, dim=1, keepdim=True) +# Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) +# return self.gamma * (x * Nx) + self.beta + x + +class FeedForward(Module): + def __init__(self, dim, dim_out=None, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + self.project_in = Linear(dim, inner_dim) + self.ff = Linear(inner_dim, dim_out) + + def forward(self, x): + return self.ff(gelu(self.project_in(x))) + +# class ConvNeXtV2Block(Module): +# def __init__( +# self, +# dim: int, +# intermediate_dim: int, +# dilation: int = 1, +# ): +# super().__init__() +# padding = (dilation * (7 - 1)) // 2 +# self.dwconv = nn.Conv1d( +# dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation +# ) # depthwise conv +# self.norm = nn.LayerNorm(dim, eps=1e-6) +# self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers +# self.act = nn.GELU() +# self.grn = GRN(intermediate_dim) +# self.pwconv2 = nn.Linear(intermediate_dim, dim) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# residual = x +# x = x.transpose(1, 2) # b n d -> b d n +# x = self.dwconv(x) +# x = x.transpose(1, 2) # b d n -> b n d +# x = self.norm(x) +# x = self.pwconv1(x) +# x = self.act(x) +# x = self.grn(x) +# x = self.pwconv2(x) +# return residual + x + 
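+# Note: the engine does not compute rotary tables in-graph; rope_cos / rope_sin are
+# explicit network inputs of width head_dim = 64 (see prepare_inputs in model.py), so
+# they have to be precomputed on the host. A minimal host-side sketch of one compatible
+# way to build them, assuming the interleaved pair layout consumed by
+# rotate_every_two_3dim below and the standard theta = 10000:
+#
+# def build_rope_tables(seq_len: int, head_dim: int = 64, theta: float = 10000.0):
+#     freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
+#     angles = torch.outer(torch.arange(seq_len).float(), freqs)  # (seq_len, head_dim // 2)
+#     rope_cos = angles.cos().repeat_interleave(2, dim=-1)        # (seq_len, head_dim)
+#     rope_sin = angles.sin().repeat_interleave(2, dim=-1)
+#     return rope_cos, rope_sin  # expand/tile to (b, n, head_dim) before feeding the engine
+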
+# def get_pos_embed_indices(start, length, max_pos, scale=1.0): +# # length = length if isinstance(length, int) else length.max() +# scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar +# pos = ( +# unsqueeze(start, 1) +# + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() +# ) +# # avoid extra long error. +# pos = torch.where(pos < max_pos, pos, max_pos - 1) +# return pos + +# def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): +# # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning +# # has some connection to NTK literature +# # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ +# # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py +# theta *= theta_rescale_factor ** (dim / (dim - 2)) +# freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) +# t = torch.arange(end, device=freqs.device) # type: ignore +# freqs = torch.outer(t, freqs).float() # type: ignore +# freqs_cos = torch.cos(freqs) # real part +# freqs_sin = torch.sin(freqs) # imaginary part +# return torch.cat([freqs_cos, freqs_sin], dim=-1) + +class AdaLayerNormZero(Module): + def __init__(self, dim): + super().__init__() + + self.linear = Linear(dim, dim * 6) + self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1) + x = self.norm(x) + ones = constant(np.ones(1, dtype = np.float32)).cast(x.dtype) + if default_net().plugin_config.remove_input_padding: + x = x * (ones + scale_msa) + shift_msa + else: + x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1) + # x = x * unsqueeze((ones + scale_msa), 1) + unsqueeze(shift_msa, 1) + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + +class AdaLayerNormZero_Final(Module): + def __init__(self, dim): + super().__init__() + + self.linear = Linear(dim, dim * 2) + + self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(silu(emb)) + scale, shift = chunk(emb, 2, dim=1) + # scale ----> (1, 1024) + # x ----> (1, -1, 1024) + ones = constant(np.ones(1, dtype = np.float32)).cast(x.dtype) + if default_net().plugin_config.remove_input_padding: + x = self.norm(x) * (ones + scale) + shift + else: + x = self.norm(x) * unsqueeze((ones + scale), 1) + x = x + unsqueeze(shift, 1) + return x + +class ConvPositionEmbedding(Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) + self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) + self.mish = Mish() + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + # if mask is not None: + # mask = mask[..., None] + # x = x.masked_fill(~mask, 0.0) + if default_net().plugin_config.remove_input_padding: + x = unsqueeze(x, 0) + x = permute(x, [0, 2, 1]) + x = self.mish(self.conv1d2(self.mish(self.conv1d1(x)))) + out = permute(x, [0, 2, 1]) + if default_net().plugin_config.remove_input_padding: + out = squeeze(out, 0) + # if mask is not None: + # out = out.masked_fill(~mask, 0.0) + return out + +class 
Attention(Module): + def __init__( + self, + processor: AttnProcessor, + dim: int, + heads: int = 16, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only = None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.processor = processor + + self.dim = dim # hidden_size + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + self.attention_head_size = dim_head + self.context_dim = context_dim + self.context_pre_only = context_pre_only + self.tp_size = 1 + self.num_attention_heads = heads // self.tp_size + self.num_attention_kv_heads = heads // self.tp_size # 8 + self.dtype = str_dtype_to_trt('float32') + self.attention_hidden_size = self.attention_head_size * self.num_attention_heads + # self.to_q = Linear(dim, self.inner_dim) + self.to_q = ColumnLinear(dim, + self.tp_size * self.num_attention_heads *self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size) + self.to_k = ColumnLinear(dim, + self.tp_size * self.num_attention_heads *self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size) + self.to_v = ColumnLinear(dim, + self.tp_size * self.num_attention_heads *self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size) + + if self.context_dim is not None: + self.to_k_c = Linear(context_dim, self.inner_dim) + self.to_v_c = Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = Linear(context_dim, self.inner_dim) + + # self.to_out = Linear(self.inner_dim, dim) + self.to_out = RowLinear(self.tp_size * self.num_attention_heads * + self.attention_head_size, + dim, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size) + # self.to_out.append(Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = Linear(self.inner_dim, dim) + + + def forward( + self, + x: float['b n d'], # noised input x + rope_cos, + rope_sin, + input_lengths, + c: float['b n d'] = None, # context c + scale = 1.0, + rope=None, + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) + +def rotate_every_two_3dim(tensor: Tensor) -> Tensor: + shape_tensor = concat([ + shape(tensor, i) / 2 if i == (tensor.ndim() - + 1) else shape(tensor, i) + for i in range(tensor.ndim()) + ]) + if default_net().plugin_config.remove_input_padding: + assert tensor.ndim() == 2 + x1 = slice(tensor, [0, 0], shape_tensor, [1, 2]) + x2 = slice(tensor, [0, 1], shape_tensor, [1, 2]) + x1 = expand_dims(x1, 2) + x2 = expand_dims(x2, 2) + zero = constant( + np.ascontiguousarray( + np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + x2 = zero - x2 + x = concat([x2, x1], 2) + out = view( + x, concat([shape(x, 0), + shape(x, 1)*2])) + else: + assert tensor.ndim() == 3 + + x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2]) + x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2]) + x1 = expand_dims(x1, 3) + x2 = expand_dims(x2, 3) + zero = constant( + np.ascontiguousarray( + np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + x2 = 
zero - x2 + x = concat([x2, x1], 3) + out = view( + x, concat([shape(x, 0), + shape(x, 1), + shape(x, 2)*2])) + + return out + +def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin): + if default_net().plugin_config.remove_input_padding: + rot_dim = shape(rope_cos, -1) #64 + new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64) + x_ = slice(x, [0, 0], new_t_shape, [1, 1]) + end_dim = shape(x, -1) - shape(rope_cos, -1) + new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960) + x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1]) + out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim = -1) + else: + rot_dim = shape(rope_cos, 2) #64 + new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64) + x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1]) + end_dim = shape(x, 2) - shape(rope_cos, 2) + new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960) + x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1]) + out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim = -1) + # t -> (2,-1,1024) freqs -> (-1,64) + return out + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float['b n d'], # noised input x + rope_cos, + rope_sin, + input_lengths, + scale = 1.0, + rope=None, + ) -> torch.FloatTensor: + + + + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + # k,v,q all (2,1226,1024) + query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin) + key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + norm_factor = math.sqrt(attn.attention_head_size) + q_scaling = 1.0 / norm_factor + mask = None + if not default_net().plugin_config.remove_input_padding: + N = shape(x, 1) + B = shape(x, 0) + seq_len_2d = concat([1, N]) + max_position_embeddings = 4096 + # create position ids + position_ids_buffer = constant( + np.expand_dims( + np.arange(max_position_embeddings).astype(np.int32), + 0)) + tmp_position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, concat([B, N])) #BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 + tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) #BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = mask.cast('int32') + + if default_net().plugin_config.bert_attention_plugin: + qkv = concat([query, key, value], dim = -1) + # TRT plugin mode + assert input_lengths is not None + if default_net().plugin_config.remove_input_padding: + qkv = qkv.view( + concat([-1, 3 * inner_dim])) + max_input_length = constant( + np.zeros([ + 2048, + ], dtype=np.int32)) + print("============================================================================") + else: + max_input_length = None + print("******************************************************************************************************") + context = bert_attention(qkv, + input_lengths, + attn.num_attention_heads, + attn.attention_head_size, + q_scaling=q_scaling, + max_input_length=max_input_length) + else: + assert not default_net().plugin_config.remove_input_padding + + def transpose_for_scores(x): + new_x_shape = concat([ + shape(x, 0), + shape(x, 1), attn.num_attention_heads, attn.attention_head_size + ]) + + y = x.view(new_x_shape) + y = y.transpose(1, 2) + return y + + def transpose_for_scores_k(x): + new_x_shape = concat([ + 
shape(x, 0), + shape(x, 1), attn.num_attention_heads, attn.attention_head_size + ]) + + y = x.view(new_x_shape) + y = y.permute([0, 2, 3, 1]) + return y + + query = transpose_for_scores(query) + key = transpose_for_scores_k(key) + value = transpose_for_scores(value) + + attention_scores = matmul(query, key, use_fp32_acc=False) + + if mask is not None: + attention_mask = expand_mask(mask, shape(query, 2)) + attention_mask = cast(attention_mask, attention_scores.dtype) + attention_scores = attention_scores + attention_mask + + attention_probs = softmax(attention_scores, dim=-1) + + context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2) + context = context.view( + concat([ + shape(context, 0), + shape(context, 1), attn.attention_hidden_size + ])) + context = attn.to_out(context) + if mask is not None: + mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1])) + mask = expand_dims_like(mask, context) + mask = cast(mask, context.dtype) + # mask = where(mask ==0, 0.0, 1.0) + context = context * mask + return context + +# DiT Block +class DiTBlock(Module): + + def __init__(self, dim, heads, dim_head, ff_mult = 2, dropout = 0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor = AttnProcessor(), + dim = dim, + heads = heads, + dim_head = dim_head, + dropout = dropout, + ) + + self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout) + + def forward(self, x, t, rope_cos, rope_sin, input_lengths, scale = 1.0, rope = ModuleNotFoundError): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + # attention + # norm ----> (2,1226,1024) + attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) + + # process attention output for input x + if default_net().plugin_config.remove_input_padding: + x = x + gate_msa * attn_output + else: + x = x + unsqueeze(gate_msa, 1) * attn_output + ones = constant(np.ones(1, dtype = np.float32)).cast(x.dtype) + if default_net().plugin_config.remove_input_padding: + norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp + else: + norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1) + # norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp + ff_output = self.ff(norm) + if default_net().plugin_config.remove_input_padding: + x = x + gate_mlp * ff_output + else: + x = x + unsqueeze(gate_mlp, 1) * ff_output + + return x + +# class SinusPositionEmbedding(Module): +# def __init__(self, dim): +# super().__init__() +# self.dim = dim + +# def forward(self, x, scale=1000): +# half_dim = self.dim // 2 +# emb = math.log(10000) / (half_dim - 1) +# emb = exp(arange(start=0, end=half_dim, dtype=trt_dtype_to_str(trt.float32)) * - emb) +# emb = scale * unsqueeze(x, 1) * unsqueeze(emb, 0) +# emb = concat([cos(emb), sin(emb)], dim=-1) +# emb = emb.cast(x.dtype) +# assert self.dim % 2 == 0 +# return emb + +class TimestepEmbedding(Module): + def __init__(self, dim, freq_embed_dim=256, dtype=None): + super().__init__() + # self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype) + self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype) + + def forward(self, timestep: float["b n"]): # noqa: F821 + t_freq = self.mlp1(timestep) + t_freq = silu(t_freq) + t_emb = self.mlp2(t_freq) + return 
t_emb diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh new file mode 100644 index 0000000..ed38497 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/run.sh @@ -0,0 +1,62 @@ + + +stage=$1 +stop_stage=$2 +model=$3 # F5TTS_Base +if [ -z "$model" ]; then + echo "Model is none" + exit 1 +fi +echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model" +export CUDA_VISIBLE_DEVICES=0 + +F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS +F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt +F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine + +num_task=2 +log_dir=./log_concurrent_tasks_${num_task} +vocoder_trt_engine_path=vocos_vocoder.plan +model_repo=./model_repo + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Copying f5 tts trtllm files" + python_package_path=/usr/local/lib/python3.12/dist-packages + cp -r patch/* $python_package_path/tensorrt_llm/models +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + echo "Downloading f5 tts from huggingface" + huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH + python3 ./scripts/convert_checkpoint.py \ + --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \ + --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model + trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \ + --max_batch_size 8 \ + --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + echo "Exporting vocos vocoder" + onnx_vocoder_path=vocos_vocoder.onnx + python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path + bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + echo "Building triton server" + rm -r $model_repo + cp -r ./model_repo_f5_tts $model_repo + python3 fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos + cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + echo "Starting triton server" + tritonserver --model-repository=$model_repo +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + echo "Testing triton server" + python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir +fi \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py new file mode 100644 index 0000000..4e0d5d3 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py @@ -0,0 +1,243 @@ +# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
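+# This Conv1d-based STFT/iSTFT is what export_vocoder_to_onnx.py swaps into the Vocos
+# ISTFT head, so the inverse transform stays expressible with ONNX/TensorRT-friendly ops.
+# Minimal usage sketch (parameter values are illustrative; the vocoder export reads the
+# real ones from the Vocos config, e.g. fft_len=1024, win_hop=256):
+#
+#   stft = STFT(win_len=1024, win_hop=256, fft_len=1024)
+#   real, imag = stft.transform(wav, return_type='realimag')   # wav: (batch, num_samples)
+#   wav_rec = stft.inverse(real, imag, input_type='realimag')  # (batch, num_samples)
+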
+ +# MIT License + +# Copyright (c) 2020 Shimin Zhang + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch as th +import torch.nn.functional as F +from scipy.signal import check_COLA, get_window + +support_clp_op = None +if th.__version__ >= '1.7.0': + from torch.fft import rfft as fft + support_clp_op = True +else: + from torch import rfft as fft + +class STFT(th.nn.Module): + def __init__(self, win_len=1024, win_hop=512, fft_len=1024, + enframe_mode='continue', win_type='hann', + win_sqrt=False, pad_center=True): + """ + Implement of STFT using 1D convolution and 1D transpose convolutions. + Implement of framing the signal in 2 ways, `break` and `continue`. + `break` method is a kaldi-like framing. + `continue` method is a librosa-like framing. + + More information about `perfect reconstruction`: + 1. https://ww2.mathworks.cn/help/signal/ref/stft.html + 2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html + + Args: + win_len (int): Number of points in one frame. Defaults to 1024. + win_hop (int): Number of framing stride. Defaults to 512. + fft_len (int): Number of DFT points. Defaults to 1024. + enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'. + win_type (str, optional): The type of window to create. Defaults to 'hann'. + win_sqrt (bool, optional): using square root window. Defaults to True. + pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True. + """ + super(STFT, self).__init__() + assert enframe_mode in ['break', 'continue'] + assert fft_len >= win_len + self.win_len = win_len + self.win_hop = win_hop + self.fft_len = fft_len + self.mode = enframe_mode + self.win_type = win_type + self.win_sqrt = win_sqrt + self.pad_center = pad_center + self.pad_amount = self.fft_len // 2 + + en_k, fft_k, ifft_k, ola_k = self.__init_kernel__() + self.register_buffer('en_k', en_k) + self.register_buffer('fft_k', fft_k) + self.register_buffer('ifft_k', ifft_k) + self.register_buffer('ola_k', ola_k) + + def __init_kernel__(self): + """ + Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel. + ** enframe_kernel: Using conv1d layer and identity matrix. + ** fft_kernel: Using linear layer for matrix multiplication. In fact, + enframe_kernel and fft_kernel can be combined, But for the sake of + readability, I took the two apart. + ** ifft_kernel, pinv of fft_kernel. + ** overlap-add kernel, just like enframe_kernel, but transposed. + + Returns: + tuple: four kernels. 
+ """ + enframed_kernel = th.eye(self.fft_len)[:, None, :] + if support_clp_op: + tmp = fft(th.eye(self.fft_len)) + fft_kernel = th.stack([tmp.real, tmp.imag], dim=2) + else: + fft_kernel = fft(th.eye(self.fft_len), 1) + if self.mode == 'break': + enframed_kernel = th.eye(self.win_len)[:, None, :] + fft_kernel = fft_kernel[:self.win_len] + fft_kernel = th.cat( + (fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1) + ifft_kernel = th.pinverse(fft_kernel)[:, None, :] + window = get_window(self.win_type, self.win_len) + + self.perfect_reconstruct = check_COLA( + window, + self.win_len, + self.win_len-self.win_hop) + window = th.FloatTensor(window) + if self.mode == 'continue': + left_pad = (self.fft_len - self.win_len)//2 + right_pad = left_pad + (self.fft_len - self.win_len) % 2 + window = F.pad(window, (left_pad, right_pad)) + if self.win_sqrt: + self.padded_window = window + window = th.sqrt(window) + else: + self.padded_window = window**2 + + fft_kernel = fft_kernel.T * window + ifft_kernel = ifft_kernel * window + ola_kernel = th.eye(self.fft_len)[:self.win_len, None, :] + if self.mode == 'continue': + ola_kernel = th.eye(self.fft_len)[:, None, :self.fft_len] + return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel + + def is_perfect(self): + """ + Whether the parameters win_len, win_hop and win_sqrt + obey constants overlap-add(COLA) + + Returns: + bool: Return true if parameters obey COLA. + """ + return self.perfect_reconstruct and self.pad_center + + def transform(self, inputs, return_type='complex'): + """Take input data (audio) to STFT domain. + + Args: + inputs (tensor): Tensor of floats, with shape (num_batch, num_samples) + return_type (str, optional): return (mag, phase) when `magphase`, + return (real, imag) when `realimag` and complex(real, imag) when `complex`. + Defaults to 'complex'. + + Returns: + tuple: (mag, phase) when `magphase`, return (real, imag) when + `realimag`. Defaults to 'complex', each elements with shape + [num_batch, num_frequencies, num_frames] + """ + assert return_type in ['magphase', 'realimag', 'complex'] + if inputs.dim() == 2: + inputs = th.unsqueeze(inputs, 1) + self.num_samples = inputs.size(-1) + if self.pad_center: + inputs = F.pad( + inputs, (self.pad_amount, self.pad_amount), mode='reflect') + enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop) + outputs = th.transpose(enframe_inputs, 1, 2) + outputs = F.linear(outputs, self.fft_k) + outputs = th.transpose(outputs, 1, 2) + dim = self.fft_len//2+1 + real = outputs[:, :dim, :] + imag = outputs[:, dim:, :] + if return_type == 'realimag': + return real, imag + elif return_type == 'complex': + assert support_clp_op + return th.complex(real, imag) + else: + mags = th.sqrt(real**2+imag**2) + phase = th.atan2(imag, real) + return mags, phase + + def inverse(self, input1, input2=None, input_type='magphase'): + """Call the inverse STFT (iSTFT), given tensors produced + by the `transform` function. + + Args: + input1 (tensors): Magnitude/Real-part of STFT with shape + [num_batch, num_frequencies, num_frames] + input2 (tensors): Phase/Imag-part of STFT with shape + [num_batch, num_frequencies, num_frames] + input_type (str, optional): Mathematical meaning of input tensor's. + Defaults to 'magphase'. + + Returns: + tensors: Reconstructed audio given magnitude and phase. 
Of + shape [num_batch, num_samples] + """ + assert input_type in ['magphase', 'realimag'] + if input_type == 'realimag': + real, imag = None, None + if support_clp_op and th.is_complex(input1): + real, imag = input1.real, input1.imag + else: + real, imag = input1, input2 + else: + real = input1*th.cos(input2) + imag = input1*th.sin(input2) + inputs = th.cat([real, imag], dim=1) + outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop) + t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1)) + t = t.to(inputs.device) + coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop) + + num_frames = input1.size(-1) + num_samples = num_frames * self.win_hop + + rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples + + outputs = outputs[..., rm_start:rm_end] + coff = coff[..., rm_start:rm_end] + coffidx = th.where(coff > 1e-8) + outputs[coffidx] = outputs[coffidx]/(coff[coffidx]) + return outputs.squeeze(dim=1) + + def forward(self, inputs): + """Take input data (audio) to STFT domain and then back to audio. + + Args: + inputs (tensor): Tensor of floats, with shape [num_batch, num_samples] + + Returns: + tensor: Reconstructed audio given magnitude and phase. + Of shape [num_batch, num_samples] + """ + mag, phase = self.transform(inputs) + rec_wav = self.inverse(mag, phase) + return rec_wav \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py b/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py new file mode 100644 index 0000000..f86c87f --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -0,0 +1,393 @@ +import argparse +import json +import os +import re +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +import safetensors.torch +import torch + +import tensorrt_llm +from tensorrt_llm import str_dtype_to_torch +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.convert_utils import (split, split_matrix_tp, + split_qkv_bias_tp, split_qkv_tp) + +def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank): + split_v = split(v, tensor_parallel, rank, dim=1) + return split_v.contiguous() + +def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): + split_v = split(v, tensor_parallel, rank, dim=0) + return split_v.contiguous() + +FACEBOOK_DIT_NAME_MAPPING = { + '^time_embed.time_mlp.0.weight$': 'time_embed.mlp1.weight', + '^time_embed.time_mlp.0.bias$': 'time_embed.mlp1.bias', + '^time_embed.time_mlp.2.weight$': 'time_embed.mlp2.weight', + '^time_embed.time_mlp.2.bias$': 'time_embed.mlp2.bias', + '^input_embed.conv_pos_embed.conv1d.0.weight$': 'input_embed.conv_pos_embed.conv1d1.weight', + '^input_embed.conv_pos_embed.conv1d.0.bias$': 'input_embed.conv_pos_embed.conv1d1.bias', + '^input_embed.conv_pos_embed.conv1d.2.weight$': 'input_embed.conv_pos_embed.conv1d2.weight', + '^input_embed.conv_pos_embed.conv1d.2.bias$': 'input_embed.conv_pos_embed.conv1d2.bias', + '^transformer_blocks.0.attn.to_out.0.weight$': 'transformer_blocks.0.attn.to_out.weight', + '^transformer_blocks.0.attn.to_out.0.bias$': 'transformer_blocks.0.attn.to_out.bias', + '^transformer_blocks.1.attn.to_out.0.weight$': 'transformer_blocks.1.attn.to_out.weight', + '^transformer_blocks.1.attn.to_out.0.bias$': 'transformer_blocks.1.attn.to_out.bias', + '^transformer_blocks.2.attn.to_out.0.weight$': 'transformer_blocks.2.attn.to_out.weight', + '^transformer_blocks.2.attn.to_out.0.bias$': 'transformer_blocks.2.attn.to_out.bias', + 
'^transformer_blocks.3.attn.to_out.0.weight$': 'transformer_blocks.3.attn.to_out.weight', + '^transformer_blocks.3.attn.to_out.0.bias$': 'transformer_blocks.3.attn.to_out.bias', + '^transformer_blocks.4.attn.to_out.0.weight$': 'transformer_blocks.4.attn.to_out.weight', + '^transformer_blocks.4.attn.to_out.0.bias$': 'transformer_blocks.4.attn.to_out.bias', + '^transformer_blocks.5.attn.to_out.0.weight$': 'transformer_blocks.5.attn.to_out.weight', + '^transformer_blocks.5.attn.to_out.0.bias$': 'transformer_blocks.5.attn.to_out.bias', + '^transformer_blocks.6.attn.to_out.0.weight$': 'transformer_blocks.6.attn.to_out.weight', + '^transformer_blocks.6.attn.to_out.0.bias$': 'transformer_blocks.6.attn.to_out.bias', + '^transformer_blocks.7.attn.to_out.0.weight$': 'transformer_blocks.7.attn.to_out.weight', + '^transformer_blocks.7.attn.to_out.0.bias$': 'transformer_blocks.7.attn.to_out.bias', + '^transformer_blocks.8.attn.to_out.0.weight$': 'transformer_blocks.8.attn.to_out.weight', + '^transformer_blocks.8.attn.to_out.0.bias$': 'transformer_blocks.8.attn.to_out.bias', + '^transformer_blocks.9.attn.to_out.0.weight$': 'transformer_blocks.9.attn.to_out.weight', + '^transformer_blocks.9.attn.to_out.0.bias$': 'transformer_blocks.9.attn.to_out.bias', + '^transformer_blocks.10.attn.to_out.0.weight$': 'transformer_blocks.10.attn.to_out.weight', + '^transformer_blocks.10.attn.to_out.0.bias$': 'transformer_blocks.10.attn.to_out.bias', + '^transformer_blocks.11.attn.to_out.0.weight$': 'transformer_blocks.11.attn.to_out.weight', + '^transformer_blocks.11.attn.to_out.0.bias$': 'transformer_blocks.11.attn.to_out.bias', + '^transformer_blocks.12.attn.to_out.0.weight$': 'transformer_blocks.12.attn.to_out.weight', + '^transformer_blocks.12.attn.to_out.0.bias$': 'transformer_blocks.12.attn.to_out.bias', + '^transformer_blocks.13.attn.to_out.0.weight$': 'transformer_blocks.13.attn.to_out.weight', + '^transformer_blocks.13.attn.to_out.0.bias$': 'transformer_blocks.13.attn.to_out.bias', + '^transformer_blocks.14.attn.to_out.0.weight$': 'transformer_blocks.14.attn.to_out.weight', + '^transformer_blocks.14.attn.to_out.0.bias$': 'transformer_blocks.14.attn.to_out.bias', + '^transformer_blocks.15.attn.to_out.0.weight$': 'transformer_blocks.15.attn.to_out.weight', + '^transformer_blocks.15.attn.to_out.0.bias$': 'transformer_blocks.15.attn.to_out.bias', + '^transformer_blocks.16.attn.to_out.0.weight$': 'transformer_blocks.16.attn.to_out.weight', + '^transformer_blocks.16.attn.to_out.0.bias$': 'transformer_blocks.16.attn.to_out.bias', + '^transformer_blocks.17.attn.to_out.0.weight$': 'transformer_blocks.17.attn.to_out.weight', + '^transformer_blocks.17.attn.to_out.0.bias$': 'transformer_blocks.17.attn.to_out.bias', + '^transformer_blocks.18.attn.to_out.0.weight$': 'transformer_blocks.18.attn.to_out.weight', + '^transformer_blocks.18.attn.to_out.0.bias$': 'transformer_blocks.18.attn.to_out.bias', + '^transformer_blocks.19.attn.to_out.0.weight$': 'transformer_blocks.19.attn.to_out.weight', + '^transformer_blocks.19.attn.to_out.0.bias$': 'transformer_blocks.19.attn.to_out.bias', + '^transformer_blocks.20.attn.to_out.0.weight$': 'transformer_blocks.20.attn.to_out.weight', + '^transformer_blocks.20.attn.to_out.0.bias$': 'transformer_blocks.20.attn.to_out.bias', + '^transformer_blocks.21.attn.to_out.0.weight$': 'transformer_blocks.21.attn.to_out.weight', + '^transformer_blocks.21.attn.to_out.0.bias$': 'transformer_blocks.21.attn.to_out.bias', + '^transformer_blocks.0.ff.ff.0.0.weight$': 
'transformer_blocks.0.ff.project_in.weight', + '^transformer_blocks.0.ff.ff.0.0.bias$': 'transformer_blocks.0.ff.project_in.bias', + '^transformer_blocks.0.ff.ff.2.weight$': 'transformer_blocks.0.ff.ff.weight', + '^transformer_blocks.0.ff.ff.2.bias$': 'transformer_blocks.0.ff.ff.bias', + '^transformer_blocks.1.ff.ff.0.0.weight$': 'transformer_blocks.1.ff.project_in.weight', + '^transformer_blocks.1.ff.ff.0.0.bias$': 'transformer_blocks.1.ff.project_in.bias', + '^transformer_blocks.1.ff.ff.2.weight$': 'transformer_blocks.1.ff.ff.weight', + '^transformer_blocks.1.ff.ff.2.bias$': 'transformer_blocks.1.ff.ff.bias', + '^transformer_blocks.2.ff.ff.0.0.weight$': 'transformer_blocks.2.ff.project_in.weight', + '^transformer_blocks.2.ff.ff.0.0.bias$': 'transformer_blocks.2.ff.project_in.bias', + '^transformer_blocks.2.ff.ff.2.weight$': 'transformer_blocks.2.ff.ff.weight', + '^transformer_blocks.2.ff.ff.2.bias$': 'transformer_blocks.2.ff.ff.bias', + '^transformer_blocks.3.ff.ff.0.0.weight$': 'transformer_blocks.3.ff.project_in.weight', + '^transformer_blocks.3.ff.ff.0.0.bias$': 'transformer_blocks.3.ff.project_in.bias', + '^transformer_blocks.3.ff.ff.2.weight$': 'transformer_blocks.3.ff.ff.weight', + '^transformer_blocks.3.ff.ff.2.bias$': 'transformer_blocks.3.ff.ff.bias', + '^transformer_blocks.4.ff.ff.0.0.weight$': 'transformer_blocks.4.ff.project_in.weight', + '^transformer_blocks.4.ff.ff.0.0.bias$': 'transformer_blocks.4.ff.project_in.bias', + '^transformer_blocks.4.ff.ff.2.weight$': 'transformer_blocks.4.ff.ff.weight', + '^transformer_blocks.4.ff.ff.2.bias$': 'transformer_blocks.4.ff.ff.bias', + '^transformer_blocks.5.ff.ff.0.0.weight$': 'transformer_blocks.5.ff.project_in.weight', + '^transformer_blocks.5.ff.ff.0.0.bias$': 'transformer_blocks.5.ff.project_in.bias', + '^transformer_blocks.5.ff.ff.2.weight$': 'transformer_blocks.5.ff.ff.weight', + '^transformer_blocks.5.ff.ff.2.bias$': 'transformer_blocks.5.ff.ff.bias', + '^transformer_blocks.6.ff.ff.0.0.weight$': 'transformer_blocks.6.ff.project_in.weight', + '^transformer_blocks.6.ff.ff.0.0.bias$': 'transformer_blocks.6.ff.project_in.bias', + '^transformer_blocks.6.ff.ff.2.weight$': 'transformer_blocks.6.ff.ff.weight', + '^transformer_blocks.6.ff.ff.2.bias$': 'transformer_blocks.6.ff.ff.bias', + '^transformer_blocks.7.ff.ff.0.0.weight$': 'transformer_blocks.7.ff.project_in.weight', + '^transformer_blocks.7.ff.ff.0.0.bias$': 'transformer_blocks.7.ff.project_in.bias', + '^transformer_blocks.7.ff.ff.2.weight$': 'transformer_blocks.7.ff.ff.weight', + '^transformer_blocks.7.ff.ff.2.bias$': 'transformer_blocks.7.ff.ff.bias', + '^transformer_blocks.8.ff.ff.0.0.weight$': 'transformer_blocks.8.ff.project_in.weight', + '^transformer_blocks.8.ff.ff.0.0.bias$': 'transformer_blocks.8.ff.project_in.bias', + '^transformer_blocks.8.ff.ff.2.weight$': 'transformer_blocks.8.ff.ff.weight', + '^transformer_blocks.8.ff.ff.2.bias$': 'transformer_blocks.8.ff.ff.bias', + '^transformer_blocks.9.ff.ff.0.0.weight$': 'transformer_blocks.9.ff.project_in.weight', + '^transformer_blocks.9.ff.ff.0.0.bias$': 'transformer_blocks.9.ff.project_in.bias', + '^transformer_blocks.9.ff.ff.2.weight$': 'transformer_blocks.9.ff.ff.weight', + '^transformer_blocks.9.ff.ff.2.bias$': 'transformer_blocks.9.ff.ff.bias', + '^transformer_blocks.10.ff.ff.0.0.weight$': 'transformer_blocks.10.ff.project_in.weight', + '^transformer_blocks.10.ff.ff.0.0.bias$': 'transformer_blocks.10.ff.project_in.bias', + '^transformer_blocks.10.ff.ff.2.weight$': 'transformer_blocks.10.ff.ff.weight', + 
'^transformer_blocks.10.ff.ff.2.bias$': 'transformer_blocks.10.ff.ff.bias', + '^transformer_blocks.11.ff.ff.0.0.weight$': 'transformer_blocks.11.ff.project_in.weight', + '^transformer_blocks.11.ff.ff.0.0.bias$': 'transformer_blocks.11.ff.project_in.bias', + '^transformer_blocks.11.ff.ff.2.weight$': 'transformer_blocks.11.ff.ff.weight', + '^transformer_blocks.11.ff.ff.2.bias$': 'transformer_blocks.11.ff.ff.bias', + '^transformer_blocks.12.ff.ff.0.0.weight$': 'transformer_blocks.12.ff.project_in.weight', + '^transformer_blocks.12.ff.ff.0.0.bias$': 'transformer_blocks.12.ff.project_in.bias', + '^transformer_blocks.12.ff.ff.2.weight$': 'transformer_blocks.12.ff.ff.weight', + '^transformer_blocks.12.ff.ff.2.bias$': 'transformer_blocks.12.ff.ff.bias', + '^transformer_blocks.13.ff.ff.0.0.weight$': 'transformer_blocks.13.ff.project_in.weight', + '^transformer_blocks.13.ff.ff.0.0.bias$': 'transformer_blocks.13.ff.project_in.bias', + '^transformer_blocks.13.ff.ff.2.weight$': 'transformer_blocks.13.ff.ff.weight', + '^transformer_blocks.13.ff.ff.2.bias$': 'transformer_blocks.13.ff.ff.bias', + '^transformer_blocks.14.ff.ff.0.0.weight$': 'transformer_blocks.14.ff.project_in.weight', + '^transformer_blocks.14.ff.ff.0.0.bias$': 'transformer_blocks.14.ff.project_in.bias', + '^transformer_blocks.14.ff.ff.2.weight$': 'transformer_blocks.14.ff.ff.weight', + '^transformer_blocks.14.ff.ff.2.bias$': 'transformer_blocks.14.ff.ff.bias', + '^transformer_blocks.15.ff.ff.0.0.weight$': 'transformer_blocks.15.ff.project_in.weight', + '^transformer_blocks.15.ff.ff.0.0.bias$': 'transformer_blocks.15.ff.project_in.bias', + '^transformer_blocks.15.ff.ff.2.weight$': 'transformer_blocks.15.ff.ff.weight', + '^transformer_blocks.15.ff.ff.2.bias$': 'transformer_blocks.15.ff.ff.bias', + '^transformer_blocks.16.ff.ff.0.0.weight$': 'transformer_blocks.16.ff.project_in.weight', + '^transformer_blocks.16.ff.ff.0.0.bias$': 'transformer_blocks.16.ff.project_in.bias', + '^transformer_blocks.16.ff.ff.2.weight$': 'transformer_blocks.16.ff.ff.weight', + '^transformer_blocks.16.ff.ff.2.bias$': 'transformer_blocks.16.ff.ff.bias', + '^transformer_blocks.17.ff.ff.0.0.weight$': 'transformer_blocks.17.ff.project_in.weight', + '^transformer_blocks.17.ff.ff.0.0.bias$': 'transformer_blocks.17.ff.project_in.bias', + '^transformer_blocks.17.ff.ff.2.weight$': 'transformer_blocks.17.ff.ff.weight', + '^transformer_blocks.17.ff.ff.2.bias$': 'transformer_blocks.17.ff.ff.bias', + '^transformer_blocks.18.ff.ff.0.0.weight$': 'transformer_blocks.18.ff.project_in.weight', + '^transformer_blocks.18.ff.ff.0.0.bias$': 'transformer_blocks.18.ff.project_in.bias', + '^transformer_blocks.18.ff.ff.2.weight$': 'transformer_blocks.18.ff.ff.weight', + '^transformer_blocks.18.ff.ff.2.bias$': 'transformer_blocks.18.ff.ff.bias', + '^transformer_blocks.19.ff.ff.0.0.weight$': 'transformer_blocks.19.ff.project_in.weight', + '^transformer_blocks.19.ff.ff.0.0.bias$': 'transformer_blocks.19.ff.project_in.bias', + '^transformer_blocks.19.ff.ff.2.weight$': 'transformer_blocks.19.ff.ff.weight', + '^transformer_blocks.19.ff.ff.2.bias$': 'transformer_blocks.19.ff.ff.bias', + '^transformer_blocks.20.ff.ff.0.0.weight$': 'transformer_blocks.20.ff.project_in.weight', + '^transformer_blocks.20.ff.ff.0.0.bias$': 'transformer_blocks.20.ff.project_in.bias', + '^transformer_blocks.20.ff.ff.2.weight$': 'transformer_blocks.20.ff.ff.weight', + '^transformer_blocks.20.ff.ff.2.bias$': 'transformer_blocks.20.ff.ff.bias', + '^transformer_blocks.21.ff.ff.0.0.weight$': 
'transformer_blocks.21.ff.project_in.weight', + '^transformer_blocks.21.ff.ff.0.0.bias$': 'transformer_blocks.21.ff.project_in.bias', + '^transformer_blocks.21.ff.ff.2.weight$': 'transformer_blocks.21.ff.ff.weight', + '^transformer_blocks.21.ff.ff.2.bias$': 'transformer_blocks.21.ff.ff.bias', +} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_name', + type=str, + default="F5TTS_Base", + choices=[ + "F5TTS_Base", + ]) # TODO: support F5TTS_v1_Base + parser.add_argument('--timm_ckpt', + type=str, + default="./ckpts/model_1200000.pt") + parser.add_argument('--output_dir', + type=str, + default='./tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument('--hidden_size', + type=int, + default=1024, + help='The hidden size of DiT') + parser.add_argument('--depth', + type=int, + default=22, + help='The number of DiTBlock layers') + parser.add_argument('--num_heads', + type=int, + default=16, + help='The number of heads of attention module') + parser.add_argument('--cfg_scale', type=float, default=4.0) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--cp_size', + type=int, + default=1, + help='Context parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'bfloat16', 'float16']) + parser.add_argument('--fp8_linear', + action='store_true', + help='Whether use FP8 for linear layers') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + args = parser.parse_args() + return args + + +def convert_timm_dit(args, mapping, dtype='float32'): + + weights = {} + tik = time.time() + torch_dtype = str_dtype_to_torch(dtype) + tensor_parallel = mapping.tp_size + + model_params = dict(torch.load(args.timm_ckpt)) + model_params = {k: v for k, v in model_params['ema_model_state_dict'].items() if k.startswith('ema_model.transformer')} + prefix = 'ema_model.transformer.' + model_params = {key[len(prefix):] if key.startswith(prefix) else key: value for key, value in model_params.items()} + + timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING + + def get_trtllm_name(timm_name): + for k, v in timm_to_trtllm_name.items(): + m = re.match(k, timm_name) + if m is not None: + if "*" in v: + v = v.replace("*", m.groups()[0]) + return v + return timm_name + + weights = dict() + for name, param in model_params.items(): + if name == 'input_embed.conv_pos_embed.conv1d.0.weight' or \ + name == 'input_embed.conv_pos_embed.conv1d.2.weight': + weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1) + else: + weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype) + + assert len(weights) == len(model_params) + + # new_prefix = 'f5_transformer.' 
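+    # NOTE: in the loop below, to_q / to_k weights and biases are pre-scaled by
+    # 64 ** -0.25 (head_dim = 64), so q @ k^T computed from the converted weights
+    # already carries the 1 / sqrt(head_dim) attention softmax scaling.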
+ new_prefix = '' + weights = {new_prefix+key:value for key, value in weights.items()} + import math + scale_factor = math.pow(64, -0.25) + for k, v in weights.items(): + if re.match('^transformer_blocks.*.attn.to_k.weight$', k): + weights[k] *= scale_factor + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, + tensor_parallel, mapping.tp_rank) + + elif re.match('^transformer_blocks.*.attn.to_k.bias$', k): + weights[k] *= scale_factor + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, + tensor_parallel, mapping.tp_rank) + + elif re.match('^transformer_blocks.*.attn.to_q.weight$', k): + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, + tensor_parallel, mapping.tp_rank) + weights[k] *= scale_factor + + elif re.match('^transformer_blocks.*.attn.to_q.bias$', k): + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, + tensor_parallel, mapping.tp_rank) + weights[k] *= scale_factor + + elif re.match('^transformer_blocks.*.attn.to_v.weight$', k): + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, + tensor_parallel, mapping.tp_rank) + + elif re.match('^transformer_blocks.*.attn.to_v.bias$', k): + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, + tensor_parallel, mapping.tp_rank) + + elif re.match('^transformer_blocks.*.attn.to_out.weight$', k): + weights[k] = split_matrix_tp(v, + tensor_parallel, + mapping.tp_rank, + dim=1) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Weights loaded. Total time: {t}') + return weights + + +def save_config(args): + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + config = { + 'architecture': "F5TTS", + 'dtype': args.dtype, + 'hidden_size': 1024, + 'num_hidden_layers': 22, + 'num_attention_heads': 16, + 'dim_head': 64, + 'dropout': 0.1, + 'ff_mult': 2, + 'mel_dim': 100, + 'text_num_embeds': 256, + 'text_dim': 512, + 'conv_layers': 4, + 'long_skip_connection': False, + 'mapping': { + 'world_size': args.cp_size * args.tp_size * args.pp_size, + 'cp_size': args.cp_size, + 'tp_size': args.tp_size, + 'pp_size': args.pp_size, + } + } + if args.fp8_linear: + config['quantization'] = { + 'quant_algo': "FP8", + # TODO: add support for exclude modules. + # 'exclude_modules': "*final_layer*", + } + + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + + +def covert_and_save(args, rank): + if rank == 0: + save_config(args) + + mapping = Mapping(world_size=args.cp_size * args.tp_size * args.pp_size, + rank=rank, + cp_size=args.cp_size, + tp_size=args.tp_size, + pp_size=args.pp_size) + + weights = convert_timm_dit(args, mapping, dtype=args.dtype) + + safetensors.torch.save_file( + weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + +def main(): + args = parse_arguments() + world_size = args.cp_size * args.tp_size * args.pp_size + + assert args.pp_size == 1, "PP is not supported yet." 
+ + tik = time.time() + if args.timm_ckpt is None: + return + print("start execute") + execute(args.workers, [covert_and_save] * world_size, args) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py new file mode 100644 index 0000000..02fcf74 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download + +from conv_stft import STFT +from vocos import Vocos +import argparse +opset_version = 17 + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--vocoder", + type=str, + default="vocos", + choices=["vocos", "bigvgan"], + help="Vocoder to export", + ) + parser.add_argument( + "--output-path", + type=str, + default="./vocos_vocoder.onnx", + help="Output path", + ) + return parser.parse_args() + +class ISTFTHead(nn.Module): + + def __init__(self, n_fft: int, hop_length: int): + super().__init__() + self.out = None + self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft) + + def forward(self, x: torch.Tensor): + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) + real = mag * torch.cos(p) + imag = mag * torch.sin(p) + audio = self.stft.inverse(input1=real, input2=imag, input_type='realimag') + return audio + +class VocosVocoder(nn.Module): + def __init__(self, vocos_vocoder): + super(VocosVocoder, self).__init__() + self.vocos_vocoder = vocos_vocoder + istft_head_out = self.vocos_vocoder.head.out + n_fft = self.vocos_vocoder.head.istft.n_fft + hop_length = self.vocos_vocoder.head.istft.hop_length + istft_head_for_export = ISTFTHead(n_fft, hop_length) + istft_head_for_export.out = istft_head_out + self.vocos_vocoder.head = istft_head_for_export + + def forward(self, mel): + waveform = self.vocos_vocoder.decode(mel) + return waveform + +def export_VocosVocoder(vocos_vocoder, output_path, verbose): + vocos_vocoder = VocosVocoder(vocos_vocoder).cuda() + vocos_vocoder.eval() + + dummy_batch_size = 8 + dummy_input_length = 500 + + dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda() + + with torch.no_grad(): + dummy_waveform = vocos_vocoder( + mel=dummy_mel + ) + print(dummy_waveform.shape) + + dummy_input = (dummy_mel) + + torch.onnx.export( + vocos_vocoder, + dummy_input, + output_path, + opset_version=opset_version, + do_constant_folding=True, + input_names=["mel"], + output_names=["waveform"], + dynamic_axes={ + "mel": {0: "batch_size", 2: "input_length"}, + "waveform": {0: "batch_size", 1: "output_length"}, + }, + 
verbose=verbose) + + print("Exported to {}".format(output_path)) + +def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None): + if vocoder_name == "vocos": + # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) + if is_local: + print(f"Load vocos from local path {local_path}") + config_path = f"{local_path}/config.yaml" + model_path = f"{local_path}/pytorch_model.bin" + else: + print("Download Vocos from huggingface charactr/vocos-mel-24khz") + repo_id = "charactr/vocos-mel-24khz" + config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml") + model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin") + vocoder = Vocos.from_hparams(config_path) + state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + vocoder.load_state_dict(state_dict) + vocoder = vocoder.eval().to(device) + elif vocoder_name == "bigvgan": + try: + from third_party.BigVGAN import bigvgan + except ImportError: + print("You need to follow the README to init submodule and change the BigVGAN source code.") + if is_local: + """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main""" + vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) + else: + local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir) + vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False) + + vocoder.remove_weight_norm() + vocoder = vocoder.eval().to(device) + return vocoder + +if __name__ == "__main__": + args = get_args() + vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None) + if args.vocoder == "vocos": + export_VocosVocoder(vocoder, args.output_path, verbose=False) \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh new file mode 100644 index 0000000..2702275 --- /dev/null +++ b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocos_trt.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
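
Before building the TensorRT engine with the `trtexec` command below, the ONNX export above can be sanity-checked with `onnxruntime`. This is a minimal sketch and not part of the patch; it assumes `onnxruntime` is installed and uses the default `vocos_vocoder.onnx` output path and the `mel` / `waveform` tensor names from `export_vocoder_to_onnx.py`.

```python
# Hedged sketch: quick smoke test of the exported Vocos vocoder ONNX graph.
# Assumes onnxruntime is installed; the file name and I/O names follow the
# defaults of export_vocoder_to_onnx.py above.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("vocos_vocoder.onnx", providers=["CPUExecutionProvider"])

# Dummy mel batch: (batch, n_mels=100, frames). The frame count should stay
# within the min/opt/max shapes passed to trtexec below (1 .. 3000).
mel = np.random.randn(2, 100, 500).astype(np.float32)

(waveform,) = sess.run(["waveform"], {"mel": mel})

# With hop_length = 256, the output is roughly frames * 256 samples at 24 kHz.
print(waveform.shape)
```

The `--minShapes/--optShapes/--maxShapes` values in the script below must cover the mel lengths the vocoder will actually see at runtime; with `MAX_INPUT_LENGTH=3000` frames (about 32 s of 24 kHz audio at hop 256), longer inputs would fail the engine's shape checks.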
+ +TRTEXEC="/usr/src/tensorrt/bin/trtexec" + +ONNX_PATH=$1 +ENGINE_PATH=$2 +echo "ONNX_PATH: $ONNX_PATH" +echo "ENGINE_PATH: $ENGINE_PATH" +PRECISION="fp32" + + +MIN_BATCH_SIZE=1 +OPT_BATCH_SIZE=1 +MAX_BATCH_SIZE=8 + +MIN_INPUT_LENGTH=1 +OPT_INPUT_LENGTH=1000 +MAX_INPUT_LENGTH=3000 + +MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}" +MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}" +MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}" + +${TRTEXEC} \ + --minShapes="mel:${MEL_MIN_SHAPE}" \ + --optShapes="mel:${MEL_OPT_SHAPE}" \ + --maxShapes="mel:${MEL_MAX_SHAPE}" \ + --onnx=${ONNX_PATH} \ + --saveEngine=${ENGINE_PATH} + From 5b178397e032ed31d95768a34ed2d696d426d37f Mon Sep 17 00:00:00 2001 From: root Date: Thu, 3 Apr 2025 02:34:28 +0000 Subject: [PATCH 2/7] remove unused codes --- src/f5_tts/runtime/triton_trtllm/README.md | 7 +- .../runtime/triton_trtllm/client_grpc.py | 99 ++-- .../f5_tts/1/f5_tts_trtllm.py | 204 ++++---- .../model_repo_f5_tts/f5_tts/1/model.py | 102 ++-- .../triton_trtllm/patch/f5tts/model.py | 298 +++++------ .../triton_trtllm/patch/f5tts/modules.py | 373 +++++-------- src/f5_tts/runtime/triton_trtllm/run.sh | 2 +- .../triton_trtllm/scripts/conv_stft.py | 90 ++-- .../scripts/convert_checkpoint.py | 488 ++++++++---------- .../scripts/export_vocoder_to_onnx.py | 49 +- .../{ => scripts}/fill_template.py | 12 +- 11 files changed, 776 insertions(+), 948 deletions(-) rename src/f5_tts/runtime/triton_trtllm/{ => scripts}/fill_template.py (65%) diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md index 991ef59..ca69e5f 100644 --- a/src/f5_tts/runtime/triton_trtllm/README.md +++ b/src/f5_tts/runtime/triton_trtllm/README.md @@ -23,13 +23,13 @@ docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper). ```sh -bash build_server.sh +bash run.sh 0 4 F5TTS_Base ``` ### Benchmark using Dataset ```sh num_task=2 -python3 client.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts +python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts ``` ### Benchmark Results @@ -40,5 +40,4 @@ Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs. | F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394| ### Credits -1. [F5-TTS](https://github.com/SWivid/F5-TTS) -2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm) \ No newline at end of file +1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm) \ No newline at end of file diff --git a/src/f5_tts/runtime/triton_trtllm/client_grpc.py b/src/f5_tts/runtime/triton_trtllm/client_grpc.py index 2f92ab6..0d3b154 100644 --- a/src/f5_tts/runtime/triton_trtllm/client_grpc.py +++ b/src/f5_tts/runtime/triton_trtllm/client_grpc.py @@ -57,7 +57,6 @@ import tritonclient.grpc.aio as grpcclient from tritonclient.utils import np_to_triton_dtype - def write_triton_stats(stats, summary_file): with open(summary_file, "w") as summary_f: model_stats = stats["model_stats"] @@ -66,12 +65,8 @@ def write_triton_stats(stats, summary_file): "The log is parsing from triton_client.get_inference_statistics(), to better human readability. 
\n" ) summary_f.write("To learn more about the log, please refer to: \n") - summary_f.write( - "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n" - ) - summary_f.write( - "2. https://github.com/triton-inference-server/server/issues/5374 \n\n" - ) + summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n") + summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n") summary_f.write( "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" ) @@ -92,9 +87,7 @@ def write_triton_stats(stats, summary_file): total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 - total_output_time_s = ( - int(model_inference_stats["compute_output"]["ns"]) / 1e9 - ) + total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9 summary_f.write( f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa ) @@ -105,30 +98,23 @@ def write_triton_stats(stats, summary_file): compute_output = batch["compute_output"] compute_infer = batch["compute_infer"] batch_count = int(compute_infer["count"]) - assert ( - compute_infer["count"] - == compute_output["count"] - == compute_input["count"] - ) + assert compute_infer["count"] == compute_output["count"] == compute_input["count"] compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 compute_input_time_ms = int(compute_input["ns"]) / 1e6 compute_output_time_ms = int(compute_output["ns"]) / 1e6 summary_f.write( - f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms/batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms/batch_count/batch_size:.2f} ms \n" # noqa + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa + ) + summary_f.write( + f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa + ) + summary_f.write( + f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa ) - # summary_f.write( - # f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa - # ) - # summary_f.write( - # f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa - # ) - def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--server-addr", @@ -191,9 +177,7 @@ def get_args(): "--model-name", type=str, default="f5_tts", - choices=[ - 
"f5_tts", "spark_tts" - ], + choices=["f5_tts", "spark_tts"], help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", ) @@ -246,10 +230,12 @@ def load_audio(wav_path, target_sample_rate=16000): waveform, sample_rate = sf.read(wav_path) if sample_rate != target_sample_rate: from scipy.signal import resample + num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) waveform = resample(waveform, num_samples) return waveform, target_sample_rate + async def send( manifest_item_list: list, name: str, @@ -261,7 +247,6 @@ async def send( audio_save_dir: str = "./", ): total_duration = 0.0 - results = [] latency_data = [] task_id = int(name[5:]) @@ -282,9 +267,7 @@ async def send( samples = np.zeros( ( 1, - padding_duration - * sample_rate - * ((int(duration) // padding_duration) + 1), + padding_duration * sample_rate * ((int(duration) // padding_duration) + 1), ), dtype=np.float32, ) @@ -292,18 +275,14 @@ async def send( samples[0, : len(waveform)] = waveform else: samples = waveform - + samples = samples.reshape(1, -1).astype(np.float32) inputs = [ - protocol_client.InferInput( - "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype) - ), - protocol_client.InferInput( - "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) - ), + protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)), + protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)), protocol_client.InferInput("reference_text", [1, 1], "BYTES"), - protocol_client.InferInput("target_text", [1, 1], "BYTES") + protocol_client.InferInput("target_text", [1, 1], "BYTES"), ] inputs[0].set_data_from_numpy(samples) inputs[1].set_data_from_numpy(lengths) @@ -320,17 +299,13 @@ async def send( sequence_id = 100000000 + i + task_id * 10 start = time.time() - response = await triton_client.infer( - model_name, inputs, request_id=str(sequence_id), outputs=outputs - ) + response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs) audio = response.as_numpy("waveform").reshape(-1) end = time.time() - start - audio_save_path = os.path.join( - audio_save_dir, f"{item['target_audio_path']}.wav" - ) + audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav") sf.write(audio_save_path, audio, 16000, "PCM_16") latency_data.append((end, estimated_target_duration)) @@ -338,6 +313,7 @@ async def send( return total_duration, latency_data + def load_manifests(manifest_path): with open(manifest_path, "r") as f: manifest_list = [] @@ -353,7 +329,7 @@ def load_manifests(manifest_path): "audio_filepath": prompt_wav, "reference_text": prompt_text, "target_text": gt_text, - "target_audio_path": utt + "target_audio_path": utt, } ) return manifest_list @@ -362,9 +338,7 @@ def load_manifests(manifest_path): def split_data(data, k): n = len(data) if n < k: - print( - f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}." - ) + print(f"Warning: the length of the input list ({n}) is less than k ({k}). 
Setting k to {n}.") k = n quotient = n // k @@ -395,12 +369,12 @@ async def main(): args.num_tasks = 1 args.log_interval = 1 manifest_item_list = [ - { - "reference_text": args.reference_text, - "target_text": args.target_text, - "audio_filepath": args.reference_audio, - "target_audio_path": "test", - } + { + "reference_text": args.reference_text, + "target_text": args.target_text, + "audio_filepath": args.reference_audio, + "target_audio_path": "test", + } ] elif args.huggingface_dataset: import datasets @@ -422,7 +396,7 @@ async def main(): ) else: manifest_item_list = load_manifests(args.manifest_path) - + args.num_tasks = min(args.num_tasks, len(manifest_item_list)) manifest_item_list = split_data(manifest_item_list, args.num_tasks) @@ -449,7 +423,6 @@ async def main(): end_time = time.time() elapsed = end_time - start_time - total_duration = 0.0 latency_data = [] for ans in ans_list: @@ -460,8 +433,8 @@ async def main(): s = f"RTF: {rtf:.4f}\n" s += f"total_duration: {total_duration:.3f} seconds\n" - s += f"({total_duration/3600:.2f} hours)\n" - s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n" + s += f"({total_duration / 3600:.2f} hours)\n" + s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n" latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 @@ -480,12 +453,14 @@ async def main(): name = args.split_name with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: f.write(s) - + stats = await triton_client.get_inference_statistics(model_name="", as_json=True) write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True) with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: json.dump(metadata, f, indent=4) + + if __name__ == "__main__": asyncio.run(main()) diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py index cb7f2d2..ecd12a6 100644 --- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py @@ -1,23 +1,21 @@ - import tensorrt as trt import os import math import time -from typing import List, Dict, Union, Optional +from typing import List, Optional from functools import wraps import tensorrt_llm from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch from tensorrt_llm.logger import logger -from tensorrt_llm.plugin.plugin import CustomAllReduceHelper from tensorrt_llm.runtime.session import Session import torch import torch.nn as nn import torch.nn.functional as F -def remove_tensor_padding(input_tensor, - input_tensor_lengths=None): + +def remove_tensor_padding(input_tensor, input_tensor_lengths=None): # Audio tensor case: batch, seq_len, feature_len # position_ids case: batch, seq_len assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor" @@ -33,8 +31,9 @@ def remove_tensor_padding(input_tensor, output_tensor = torch.cat(valid_sequences, dim=0).contiguous() return output_tensor + class TextEmbedding(nn.Module): - def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2, precompute_max_pos = 4096): + def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096): super().__init__() self.text_embed = 
nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False) @@ -47,7 +46,7 @@ class TextEmbedding(nn.Module): text = text[:, :text_pad_cut_off_index] text = self.text_embed(text) - text = text + self.freqs_cis[:text.shape[1], :] + text = text + self.freqs_cis[: text.shape[1], :] for block in self.text_blocks: text = block(text) # padding text to the original length @@ -67,7 +66,7 @@ class GRN(nn.Module): Gx = torch.norm(x, p=2, dim=1, keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x - + class ConvNeXtV2Block(nn.Module): def __init__( @@ -78,7 +77,9 @@ class ConvNeXtV2Block(nn.Module): ): super().__init__() padding = (dilation * (7 - 1)) // 2 - self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv self.norm = nn.LayerNorm(dim, eps=1e-6) self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() @@ -98,7 +99,7 @@ class ConvNeXtV2Block(nn.Module): return residual + x -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.): +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ @@ -111,6 +112,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca freqs_sin = torch.sin(freqs) # imaginary part return torch.cat([freqs_cos, freqs_sin], dim=-1) + def load_checkpoint(ckpt_path, use_ema=True): checkpoint = torch.load(ckpt_path, weights_only=True) if use_ema: @@ -123,12 +125,12 @@ def load_checkpoint(ckpt_path, use_ema=True): text_embed_dict = {} for key in dict_state.keys(): # transformer.text_embed.text_embed.weight -> text_embed.weight - if 'text_embed' in key: - text_embed_dict[key.replace('transformer.text_embed.', '')] = dict_state[key] + if "text_embed" in key: + text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key] return text_embed_dict -class F5TTS(object): +class F5TTS(object): def __init__( self, config, @@ -138,21 +140,16 @@ class F5TTS(object): model_path: Optional[str] = None, vocab_size: Optional[int] = None, ): - self.dtype = config['pretrained_config']['dtype'] + self.dtype = config["pretrained_config"]["dtype"] rank = tensorrt_llm.mpi_rank() - world_size = config['pretrained_config']['mapping']['world_size'] - cp_size = config['pretrained_config']['mapping']['cp_size'] - tp_size = config['pretrained_config']['mapping']['tp_size'] - pp_size = config['pretrained_config']['mapping']['pp_size'] + world_size = config["pretrained_config"]["mapping"]["world_size"] + cp_size = config["pretrained_config"]["mapping"]["cp_size"] + tp_size = config["pretrained_config"]["mapping"]["tp_size"] + pp_size = config["pretrained_config"]["mapping"]["pp_size"] assert pp_size == 1 self.mapping = tensorrt_llm.Mapping( - world_size=world_size, - rank=rank, - cp_size=cp_size, - tp_size=tp_size, - pp_size=1, - gpus_per_node=1 + world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, 
gpus_per_node=1 ) local_rank = rank % self.mapping.gpus_per_node @@ -166,7 +163,7 @@ class F5TTS(object): torch.cuda.set_stream(self.stream) engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine") - logger.info(f'Loading engine from {engine_file}') + logger.info(f"Loading engine from {engine_file}") with open(engine_file, "rb") as f: engine_buffer = f.read() @@ -180,14 +177,10 @@ class F5TTS(object): self.outputs = {} self.buffer_allocated = False - expected_tensor_names = ['noise', 'cond', 'time', 'rope_cos', 'rope_sin', 'input_lengths', 'denoised'] + expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"] - found_tensor_names = [ - self.session.engine.get_tensor_name(i) - for i in range(self.session.engine.num_io_tensors) - ] - if not self.debug_mode and set(expected_tensor_names) != set( - found_tensor_names): + found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)] + if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names): logger.error( f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}" ) @@ -196,18 +189,13 @@ class F5TTS(object): ) logger.error(f"Expected tensor names: {expected_tensor_names}") logger.error(f"Found tensor names: {found_tensor_names}") - raise RuntimeError( - "Tensor names in engine are not the same as expected.") + raise RuntimeError("Tensor names in engine are not the same as expected.") if self.debug_mode: - self.debug_tensors = list( - set(found_tensor_names) - set(expected_tensor_names)) - + self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names)) + self.max_mel_len = 4096 self.text_embedding = TextEmbedding( - text_num_embeds=vocab_size, - text_dim=512, - conv_layers=4, - precompute_max_pos=self.max_mel_len + text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len ).to(self.device) self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True) @@ -217,7 +205,7 @@ class F5TTS(object): self.win_length = 1024 self.hop_length = 256 self.n_mel_channels = 100 - #self.max_mel_len = 3000 + # self.max_mel_len = 3000 self.head_dim = 64 self.base_rescale_factor = 1.0 self.interpolation_factor = 1.0 @@ -251,20 +239,16 @@ class F5TTS(object): def _setup(self, batch_size, seq_len): for i in range(self.session.engine.num_io_tensors): name = self.session.engine.get_tensor_name(i) - if self.session.engine.get_tensor_mode( - name) == trt.TensorIOMode.OUTPUT: + if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: shape = list(self.session.engine.get_tensor_shape(name)) shape[0] = batch_size shape[1] = seq_len - self.outputs[name] = torch.empty(shape, - dtype=self._tensor_dtype(name), - device=self.device) + self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device) self.buffer_allocated = True def cuda_stream_guard(func): - """Sync external stream and set current stream to the one bound to the session. Reset on exit. - """ + """Sync external stream and set current stream to the one bound to the session. 
Reset on exit.""" @wraps(func) def wrapper(self, *args, **kwargs): @@ -281,48 +265,56 @@ class F5TTS(object): return wrapper @cuda_stream_guard - def forward(self, noise: torch.Tensor, cond: torch.Tensor, - time_expand: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, - input_lengths: torch.Tensor, delta_t: torch.Tensor, use_perf: bool = False): + def forward( + self, + noise: torch.Tensor, + cond: torch.Tensor, + time_expand: torch.Tensor, + rope_cos: torch.Tensor, + rope_sin: torch.Tensor, + input_lengths: torch.Tensor, + delta_t: torch.Tensor, + use_perf: bool = False, + ): if use_perf: torch.cuda.nvtx.range_push("flow matching") cfg_strength = 2.0 batch_size = noise.shape[0] half_batch = batch_size // 2 noise_half = noise[:half_batch] # Store the initial half of noise - + input_type = str_dtype_to_torch(self.dtype) - + # Keep a copy of the initial tensors cond = cond.to(input_type) - rope_cos = rope_cos.to(input_type) + rope_cos = rope_cos.to(input_type) rope_sin = rope_sin.to(input_type) input_lengths = input_lengths.to(str_dtype_to_torch("int32")) - + # Instead of iteratively updating noise within a single model context, # we'll do a single forward pass for each iteration with fresh context setup - for i in range(self.nfe_steps): + for i in range(self.nfe_steps): # Re-setup the buffers for clean execution self._setup(batch_size, noise.shape[1]) if not self.buffer_allocated: - raise RuntimeError('Buffer not allocated, please call setup first!') - + raise RuntimeError("Buffer not allocated, please call setup first!") + # Re-create combined noises for this iteration current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type) - + # Get time step for this iteration current_time = time_expand[:, i].to(input_type) - + # Create fresh input dictionary for this iteration current_inputs = { - 'noise': current_noise, - 'cond': cond, - 'time': current_time, - 'rope_cos': rope_cos, - 'rope_sin': rope_sin, - 'input_lengths': input_lengths, + "noise": current_noise, + "cond": cond, + "time": current_time, + "rope_cos": rope_cos, + "rope_sin": rope_sin, + "input_lengths": input_lengths, } - + # Update inputs and set shapes self.inputs.clear() # Clear previous inputs self.inputs.update(**current_inputs) @@ -330,19 +322,18 @@ class F5TTS(object): if use_perf: torch.cuda.nvtx.range_push(f"execute {i}") - ok = self.session.run(self.inputs, self.outputs, - self.stream.cuda_stream) + ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream) assert ok, "Failed to execute model" # self.session.context.execute_async_v3(self.stream.cuda_stream) if use_perf: - torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() # Process results t_scale = delta_t[i].unsqueeze(0).to(input_type) - + # Extract predictions pred_cond = self.outputs["denoised"][:half_batch] pred_uncond = self.outputs["denoised"][half_batch:] - + # Apply classifier-free guidance with safeguards guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength # Calculate update for noise @@ -351,63 +342,72 @@ class F5TTS(object): torch.cuda.nvtx.range_pop() return noise_half - def sample(self, text_pad_sequence: torch.Tensor, ref_mel_batch: torch.Tensor, - ref_mel_len_batch: torch.Tensor, estimated_reference_target_mel_len: List[int], remove_input_padding: bool = False, use_perf: bool = False): + def sample( + self, + text_pad_sequence: torch.Tensor, + ref_mel_batch: torch.Tensor, + ref_mel_len_batch: torch.Tensor, + estimated_reference_target_mel_len: List[int], + remove_input_padding: bool = False, + 
use_perf: bool = False, + ): if use_perf: torch.cuda.nvtx.range_push("text embedding") batch = text_pad_sequence.shape[0] max_seq_len = ref_mel_batch.shape[1] text_pad_sequence_drop = torch.cat( - (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), - dim=0 + (text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0 ) - + text_embedding_drop_list = [] - for i in range(batch+1): + for i in range(batch + 1): text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device))) text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0) - + text_embedding = text_embedding_drop_condition[:-1] # text_embedding_drop B,T,C batch should be the same text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1) noise = torch.randn_like(ref_mel_batch).to(self.device) - rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) + rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1) rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1) cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1) cat_mel_text_drop = torch.cat( - (torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), text_embedding_drop), - dim=-1 + ( + torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device), + text_embedding_drop, + ), + dim=-1, ) time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous() - + # Convert estimated_reference_target_mel_len to tensor input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32) - + # combine above along the batch dimension inputs = { - 'noise': torch.cat((noise, noise), dim=0).contiguous(), - 'cond': torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(), - 'time_expand': time_expand, - 'rope_cos': torch.cat((rope_cos, rope_cos), dim=0).contiguous(), - 'rope_sin': torch.cat((rope_sin, rope_sin), dim=0).contiguous(), - 'input_lengths': torch.cat((input_lengths, input_lengths), dim=0).contiguous(), - 'delta_t': self.delta_t + "noise": torch.cat((noise, noise), dim=0).contiguous(), + "cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(), + "time_expand": time_expand, + "rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(), + "rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(), + "input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(), + "delta_t": self.delta_t, } if use_perf and remove_input_padding: torch.cuda.nvtx.range_push("remove input padding") if remove_input_padding: - max_seq_len = inputs['cond'].shape[1] - inputs['noise'] = remove_tensor_padding(inputs['noise'], inputs['input_lengths']) - inputs['cond'] = remove_tensor_padding(inputs['cond'], inputs['input_lengths']) + max_seq_len = inputs["cond"].shape[1] + inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"]) + inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"]) # for time_expand, convert from B,D to B,T,D by repeat - inputs['time_expand'] = inputs['time_expand'].unsqueeze(1).repeat(1, max_seq_len, 1, 1) - inputs['time_expand'] = remove_tensor_padding(inputs['time_expand'], inputs['input_lengths']) - inputs['rope_cos'] = remove_tensor_padding(inputs['rope_cos'], inputs['input_lengths']) - inputs['rope_sin'] = remove_tensor_padding(inputs['rope_sin'], 
inputs['input_lengths']) + inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1) + inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"]) + inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"]) + inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"]) if use_perf and remove_input_padding: torch.cuda.nvtx.range_pop() for key in inputs: @@ -423,9 +423,9 @@ class F5TTS(object): denoised_list = [] start_idx = 0 for i in range(batch): - denoised_list.append(denoised[start_idx:start_idx+inputs['input_lengths'][i]]) - start_idx += inputs['input_lengths'][i] + denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]]) + start_idx += inputs["input_lengths"][i] if use_perf and remove_input_padding: torch.cuda.nvtx.range_pop() return denoised_list, cost_time - return denoised, cost_time \ No newline at end of file + return denoised, cost_time diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py index 8337185..a0ca9d3 100644 --- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py +++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py @@ -25,7 +25,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json import torch -from torch import nn from torch.nn.utils.rnn import pad_sequence import torch.nn.functional as F from torch.utils.dlpack import from_dlpack, to_dlpack @@ -33,10 +32,9 @@ import torchaudio import jieba import triton_python_backend_utils as pb_utils from pypinyin import Style, lazy_pinyin -import math import os from f5_tts_trtllm import F5TTS -torch.manual_seed(0) + def get_tokenizer(vocab_file_path: str): """ @@ -55,6 +53,7 @@ def get_tokenizer(vocab_file_path: str): vocab_size = len(vocab_char_map) return vocab_char_map, vocab_size + def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): final_reference_target_texts_list = [] custom_trans = str.maketrans( @@ -73,9 +72,7 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"": char_list.append(" ") char_list.extend(seg) - elif polyphone and seg_byte_len == 3 * len( - seg - ): # if pure east asian characters + elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True) for i, c in enumerate(seg): if is_chinese(c): @@ -87,32 +84,29 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True): char_list.extend(c) elif is_chinese(c): char_list.append(" ") - char_list.extend( - lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True) - ) + char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)) else: char_list.append(c) final_reference_target_texts_list.append(char_list) return final_reference_target_texts_list + def list_str_to_idx( text: list[str] | list[list[str]], vocab_char_map: dict[str, int], # {char: idx} padding_value=-1, ): # noqa: F722 - list_idx_tensors = [ - torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text - ] # pinyin or char style - # text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True) + list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style return list_idx_tensors + class TritonPythonModel: def 
initialize(self, args): self.use_perf = True self.device = torch.device("cuda") self.target_audio_sample_rate = 24000 - self.target_rms = 0.15 # target rms for audio + self.target_rms = 0.15 # target rms for audio self.n_fft = 1024 self.win_length = 1024 self.hop_length = 256 @@ -120,8 +114,7 @@ class TritonPythonModel: self.max_mel_len = 3000 self.head_dim = 64 - - parameters = json.loads(args['model_config'])['parameters'] + parameters = json.loads(args["model_config"])["parameters"] for key, value in parameters.items(): parameters[key] = value["string_value"] @@ -130,10 +123,16 @@ class TritonPythonModel: self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate) self.tllm_model_dir = parameters["tllm_model_dir"] - config_file = os.path.join(self.tllm_model_dir, 'config.json') + config_file = os.path.join(self.tllm_model_dir, "config.json") with open(config_file) as f: config = json.load(f) - self.model = F5TTS(config, debug_mode=False, tllm_model_dir=self.tllm_model_dir, model_path=parameters["model_path"], vocab_size=self.vocab_size) + self.model = F5TTS( + config, + debug_mode=False, + tllm_model_dir=self.tllm_model_dir, + model_path=parameters["model_path"], + vocab_size=self.vocab_size, + ) self.vocoder = parameters["vocoder"] assert self.vocoder in ["vocos", "bigvgan"] @@ -161,44 +160,44 @@ class TritonPythonModel: def forward_vocoder(self, mel): mel = mel.to(torch.float32).contiguous() input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel)) - + inference_request = pb_utils.InferenceRequest( - model_name='vocoder', - requested_output_names=['waveform'], - inputs=[input_tensor_0]) + model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0] + ) inference_response = inference_request.exec() if inference_response.has_error(): raise pb_utils.TritonModelException(inference_response.error().message()) else: - waveform = pb_utils.get_output_tensor_by_name(inference_response, - 'waveform') + waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform") waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() return waveform - + def execute(self, requests): - reference_text_list, target_text_list, reference_target_texts_list, reference_wavs_tensor, estimated_reference_target_mel_len, reference_mel_len = [], [], [], [], [], [] - max_wav_len = 0 + ( + reference_text_list, + target_text_list, + reference_target_texts_list, + estimated_reference_target_mel_len, + reference_mel_len, + ) = [], [], [], [], [] mel_features_list = [] if self.use_perf: torch.cuda.nvtx.range_push("preprocess") for request in requests: wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav") - wav_lens = pb_utils.get_input_tensor_by_name( - request, "reference_wav_len") - - reference_text = pb_utils.get_input_tensor_by_name( - request, "reference_text").as_numpy() - reference_text = reference_text[0][0].decode('utf-8') + wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode("utf-8") reference_text_list.append(reference_text) - target_text = pb_utils.get_input_tensor_by_name( - request, "target_text").as_numpy() - target_text = target_text[0][0].decode('utf-8') + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode("utf-8") target_text_list.append(target_text) text 
= reference_text + target_text reference_target_texts_list.append(text) - + wav = from_dlpack(wav_tensor.to_dlpack()) wav_len = from_dlpack(wav_lens.to_dlpack()) wav_len = wav_len.squeeze() @@ -217,31 +216,36 @@ class TritonPythonModel: if self.use_perf: torch.cuda.nvtx.range_pop() mel_features_list.append(mel_features) - + reference_mel_len.append(mel_features.shape[1]) - estimated_reference_target_mel_len.append(int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))) - + estimated_reference_target_mel_len.append( + int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text))) + ) + max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len) batch = len(requests) - print(f"The current batch is {batch}") mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device) for i, mel in enumerate(mel_features_list): - mel_features[i, :mel.shape[1], :] = mel + mel_features[i, : mel.shape[1], :] = mel reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device) - + pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True) text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map) - + for i, item in enumerate(text_pad_sequence): - text_pad_sequence[i] = F.pad(item, (0, estimated_reference_target_mel_len[i] - len(item)), mode='constant', value=-1) - text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS + text_pad_sequence[i] = F.pad( + item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1 + ) + text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device) - text_pad_sequence = F.pad(text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode='constant', value=-1) + text_pad_sequence = F.pad( + text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1 + ) if self.use_perf: torch.cuda.nvtx.range_pop() - + denoised, cost_time = self.model.sample( text_pad_sequence, mel_features, @@ -262,7 +266,7 @@ class TritonPythonModel: rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < self.target_rms: audio = audio * self.target_rms / rms - + audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) inference_response = pb_utils.InferenceResponse(output_tensors=[audio]) responses.append(inference_response) diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py index 45bfec5..e0b830b 100644 --- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py @@ -10,96 +10,39 @@ d - dimension from __future__ import annotations import sys import os -current_file_path = os.path.abspath(__file__) -parent_dir = os.path.dirname(current_file_path) -sys.path.append(parent_dir) -import math -import numpy as np -import torch -from torch import nn + import tensorrt as trt from collections import OrderedDict -from ..._utils import str_dtype_to_trt, trt_dtype_to_str, trt_dtype_to_np +from ..._utils import str_dtype_to_trt from ...plugin import current_all_reduce_helper from ..modeling_utils import PretrainedConfig, PretrainedModel -from ...functional import (Tensor, allgather, arange, chunk, concat, constant, - cos, exp, expand, shape, silu, sin, slice, split, - unsqueeze, squeeze, cast) +from ...functional import 
Tensor, concat from ...module import Module, ModuleList from tensorrt_llm._common import default_net from ...layers import Linear from .modules import ( TimestepEmbedding, - # ConvNeXtV2Block, ConvPositionEmbedding, DiTBlock, AdaLayerNormZero_Final, - # precompute_freqs_cis, get_pos_embed_indices, ) -# Text embedding -# class TextEmbedding(Module): -# def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2): -# super().__init__() -# self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token +current_file_path = os.path.abspath(__file__) +parent_dir = os.path.dirname(current_file_path) +sys.path.append(parent_dir) -# if conv_layers > 0: -# self.extra_modeling = True -# self.precompute_max_pos = 4096 # ~44s of 24khz audio -# self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False) -# self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]) -# else: -# self.extra_modeling = False - -# def forward(self, text: int['b nt'], seq_len): -# text = self.text_embed(text) # b n -> b n d - -# # possible extra modeling -# if self.extra_modeling: -# # sinus pos emb -# pos_idx = get_pos_embed_indices(torch.zeros(1, dtype=torch.int32), seq_len, max_pos=self.precompute_max_pos) -# # convnextv2 blocks -# text = self.text_blocks(text + self.freqs_cis[pos_idx]) - -# return text class InputEmbedding(Module): def __init__(self, mel_dim, text_dim, out_dim): super().__init__() self.proj = Linear(mel_dim * 2 + text_dim, out_dim) - self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim) + self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim) - def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False): + def forward(self, x, cond, drop_audio_cond=False): # if drop_audio_cond: # cfg for cond audio - x = self.proj(concat([x, cond], dim = -1)) + x = self.proj(concat([x, cond], dim=-1)) return self.conv_pos_embed(x) + x - -# Transformer backbone using DiT blocks -# class F5TTS(PretrainedModel): -# def __init__(self, config: PretrainedConfig): -# super().__init__(config) -# self.f5_transformer = DiT_transformer(config) -# self.dtype = str_dtype_to_trt(config.dtype) -# self.cfg_strength = 2 - -# def forward(self, -# noise: float['b n d'], # nosied input audio -# cond: float['b n d'], # masked cond audio -# cond_drop: float['b n d'], -# time: float['b n'], # time step -# rope_cos: float['b n d'], -# rope_sin: float['b n d'], -# t_scale: float['b'], -# mask: bool['b n'] | None = None): - -# pred = self.f5_transformer(x = noise, cond = cond, cond_drop = cond_drop, time = time, rope_cos = rope_cos, rope_sin = rope_sin, mask = mask) -# pred, pred1 = chunk(pred, 2, dim = 0), chunk works only for static tensor -# # cfg_strength = constant(np.array([self.cfg_strength], dtype = np.float32)).cast(noise.dtype) -# # noise = noise + (pred_cond + (pred_cond - pred_uncond) * cfg_strength) * t_scale -# noise.mark_output('denoised', self.dtype) -# return noise - class F5TTS(PretrainedModel): @@ -107,148 +50,187 @@ class F5TTS(PretrainedModel): super().__init__(config) self.dtype = str_dtype_to_trt(config.dtype) - self.time_embed = TimestepEmbedding(config.hidden_size) # √ - if config.text_dim is None: - text_dim = config.mel_dim - self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size) + self.time_embed = TimestepEmbedding(config.hidden_size) # √ + text_dim = config.mel_dim if config.text_dim is None else config.text_dim + 
self.input_embed = InputEmbedding(config.mel_dim, text_dim, config.hidden_size) self.dim = config.hidden_size self.depth = config.num_hidden_layers self.transformer_blocks = ModuleList( [ DiTBlock( - dim = self.dim, - heads = config.num_attention_heads, - dim_head = config.dim_head, - ff_mult = config.ff_mult, - dropout = config.dropout + dim=self.dim, + heads=config.num_attention_heads, + dim_head=config.dim_head, + ff_mult=config.ff_mult, + dropout=config.dropout, ) for _ in range(self.depth) ] ) - + self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation self.proj_out = Linear(config.hidden_size, config.mel_dim) def forward( - self, - noise: float['b n d'], # nosied input audio - cond: float['b n d'], # masked cond audio - time: float['b n'], # time step - rope_cos: float['b n d'] , - rope_sin: float['b n d'], - input_lengths: int['b'], - scale = 1.0 + self, + noise, # nosied input audio + cond, # masked cond audio + time, # time step + rope_cos, + rope_sin, + input_lengths, + scale=1.0, ): t = self.time_embed(time) x = self.input_embed(noise, cond) - # x = concat([self.input_embed(x, cond), self.input_embed(x, cond_drop)], dim = 0) - for block in self.transformer_blocks: - x = block(x, t, rope_cos = rope_cos, rope_sin = rope_sin, input_lengths=input_lengths, scale = scale) + x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) denoise = self.proj_out(self.norm_out(x, t)) - denoise.mark_output('denoised', self.dtype) + denoise.mark_output("denoised", self.dtype) return denoise def prepare_inputs(self, **kwargs): - max_batch_size = kwargs['max_batch_size'] + max_batch_size = kwargs["max_batch_size"] batch_size_range = [2, 2, max_batch_size] mel_size = 100 max_seq_len = 3000 num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size] hidden_size = 512 concat_feature_dim = mel_size + hidden_size - freq_embed_dim=256 + freq_embed_dim = 256 head_dim = 64 mapping = self.config.mapping if mapping.tp_size > 1: current_all_reduce_helper().set_workspace_tensor(mapping, 1) if default_net().plugin_config.remove_input_padding: noise = Tensor( - name='noise', + name="noise", dtype=self.dtype, shape=[-1, mel_size], - dim_range=OrderedDict([ - ('num_frames', [num_frames_range]), - ('n_mels', [mel_size]), - ])) + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("n_mels", [mel_size]), + ] + ), + ) cond = Tensor( - name='cond', + name="cond", dtype=self.dtype, shape=[-1, concat_feature_dim], - dim_range=OrderedDict([ - ('num_frames', [num_frames_range]), - ('embeded_length', [concat_feature_dim]), - ])) - time = Tensor(name='time', - dtype=self.dtype, - shape=[-1, freq_embed_dim], - dim_range=OrderedDict([ - ('num_frames', [num_frames_range]), - ('freq_dim', [freq_embed_dim]), - ])) - rope_cos = Tensor(name='rope_cos', - dtype=self.dtype, - shape=[-1, head_dim], - dim_range=OrderedDict([ - ('num_frames', [num_frames_range]), - ('head_dim', [head_dim]), - ])) - rope_sin = Tensor(name='rope_sin', - dtype=self.dtype, - shape=[-1, head_dim], - dim_range=OrderedDict([ - ('num_frames', [num_frames_range]), - ('head_dim', [head_dim]), - ])) + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("embeded_length", [concat_feature_dim]), + ] + ), + ) + time = Tensor( + name="time", + dtype=self.dtype, + shape=[-1, freq_embed_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("freq_dim", [freq_embed_dim]), + ] + ), + ) + rope_cos = Tensor( + name="rope_cos", + dtype=self.dtype, + 
shape=[-1, head_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("head_dim", [head_dim]), + ] + ), + ) + rope_sin = Tensor( + name="rope_sin", + dtype=self.dtype, + shape=[-1, head_dim], + dim_range=OrderedDict( + [ + ("num_frames", [num_frames_range]), + ("head_dim", [head_dim]), + ] + ), + ) else: noise = Tensor( - name='noise', + name="noise", dtype=self.dtype, shape=[-1, -1, mel_size], - dim_range=OrderedDict([ - ('batch_size', [batch_size_range]), - ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), - ('n_mels', [mel_size]), - ])) + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("n_mels", [mel_size]), + ] + ), + ) cond = Tensor( - name='cond', + name="cond", dtype=self.dtype, shape=[-1, -1, concat_feature_dim], - dim_range=OrderedDict([ - ('batch_size', [batch_size_range]), - ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), - ('embeded_length', [concat_feature_dim]), - ])) - print(233333333333333333333333333333333333333333333333333, batch_size_range) - time = Tensor(name='time', - dtype=self.dtype, - shape=[-1, freq_embed_dim], - dim_range=OrderedDict([ - ('batch_size', [batch_size_range]), - ('freq_dim', [freq_embed_dim]), - ])) - rope_cos = Tensor(name='rope_cos', - dtype=self.dtype, - shape=[-1, -1, head_dim], - dim_range=OrderedDict([ - ('batch_size', [batch_size_range]), - ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), - ('head_dim', [head_dim]), - ])) - rope_sin = Tensor(name='rope_sin', - dtype=self.dtype, - shape=[-1, -1, head_dim], - dim_range=OrderedDict([ - ('batch_size', [batch_size_range]), - ('max_duratuion', [[100, max_seq_len // 2, max_seq_len]]), - ('head_dim', [head_dim]), - ])) + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("embeded_length", [concat_feature_dim]), + ] + ), + ) + time = Tensor( + name="time", + dtype=self.dtype, + shape=[-1, freq_embed_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("freq_dim", [freq_embed_dim]), + ] + ), + ) + rope_cos = Tensor( + name="rope_cos", + dtype=self.dtype, + shape=[-1, -1, head_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("head_dim", [head_dim]), + ] + ), + ) + rope_sin = Tensor( + name="rope_sin", + dtype=self.dtype, + shape=[-1, -1, head_dim], + dim_range=OrderedDict( + [ + ("batch_size", [batch_size_range]), + ("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]), + ("head_dim", [head_dim]), + ] + ), + ) input_lengths = Tensor( name="input_lengths", dtype=trt.int32, shape=[-1], dim_range=OrderedDict([("batch_size", [batch_size_range])]), ) - return {'noise': noise, 'cond': cond, 'time': time, 'rope_cos': rope_cos, 'rope_sin': rope_sin, 'input_lengths': input_lengths} \ No newline at end of file + return { + "noise": noise, + "cond": cond, + "time": time, + "rope_cos": rope_cos, + "rope_sin": rope_sin, + "input_lengths": input_lengths, + } diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py index 896e3d7..5bfd5a0 100644 --- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py +++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py @@ -5,31 +5,33 @@ from typing import Optional import torch import torch.nn.functional as F -# import torchaudio -# from librosa.filters import mel as librosa_mel_fn 
-from torch import nn + import numpy as np -import tensorrt as trt from tensorrt_llm._common import default_net -from ..._utils import (fp32_array, int32_array, is_same_dtype, set_obj_attrs, - trt_dtype_to_np, trt_dtype_to_str,str_dtype_to_trt) -from ...functional import (Tensor, allgather, arange, chunk, concat, constant, - cos, exp, expand, shape, silu, sin, slice, split, permute, expand_mask, expand_dims_like, - unsqueeze, matmul, softmax, where, RopeEmbeddingUtils, minimum, repeat_interleave, squeeze, cast, gelu) +from ..._utils import trt_dtype_to_np, str_dtype_to_trt +from ...functional import ( + Tensor, + chunk, + concat, + constant, + expand, + shape, + silu, + slice, + permute, + expand_mask, + expand_dims_like, + unsqueeze, + matmul, + softmax, + squeeze, + cast, + gelu, +) from ...functional import expand_dims, view, bert_attention -from ...layers import MLP, BertAttention, Conv2d, LayerNorm, Linear, Conv1d, Mish, embedding, RowLinear, ColumnLinear -from ...module import Module, ModuleList +from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear +from ...module import Module -# class GRN(Module): -# def __init__(self, dim): -# super().__init__() -# self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) -# self.beta = nn.Parameter(torch.zeros(1, 1, dim)) - -# def forward(self, x): -# Gx = torch.norm(x, p=2, dim=1, keepdim=True) -# Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) -# return self.gamma * (x * Nx) + self.beta + x class FeedForward(Module): def __init__(self, dim, dim_out=None, mult=4, dropout=0.0): @@ -43,59 +45,6 @@ class FeedForward(Module): def forward(self, x): return self.ff(gelu(self.project_in(x))) -# class ConvNeXtV2Block(Module): -# def __init__( -# self, -# dim: int, -# intermediate_dim: int, -# dilation: int = 1, -# ): -# super().__init__() -# padding = (dilation * (7 - 1)) // 2 -# self.dwconv = nn.Conv1d( -# dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation -# ) # depthwise conv -# self.norm = nn.LayerNorm(dim, eps=1e-6) -# self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers -# self.act = nn.GELU() -# self.grn = GRN(intermediate_dim) -# self.pwconv2 = nn.Linear(intermediate_dim, dim) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# residual = x -# x = x.transpose(1, 2) # b n d -> b d n -# x = self.dwconv(x) -# x = x.transpose(1, 2) # b d n -> b n d -# x = self.norm(x) -# x = self.pwconv1(x) -# x = self.act(x) -# x = self.grn(x) -# x = self.pwconv2(x) -# return residual + x - -# def get_pos_embed_indices(start, length, max_pos, scale=1.0): -# # length = length if isinstance(length, int) else length.max() -# scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar -# pos = ( -# unsqueeze(start, 1) -# + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long() -# ) -# # avoid extra long error. 
-# pos = torch.where(pos < max_pos, pos, max_pos - 1) -# return pos - -# def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0): -# # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning -# # has some connection to NTK literature -# # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ -# # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py -# theta *= theta_rescale_factor ** (dim / (dim - 2)) -# freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) -# t = torch.arange(end, device=freqs.device) # type: ignore -# freqs = torch.outer(t, freqs).float() # type: ignore -# freqs_cos = torch.cos(freqs) # real part -# freqs_sin = torch.sin(freqs) # imaginary part -# return torch.cat([freqs_cos, freqs_sin], dim=-1) class AdaLayerNormZero(Module): def __init__(self, dim): @@ -108,14 +57,14 @@ class AdaLayerNormZero(Module): emb = self.linear(silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1) x = self.norm(x) - ones = constant(np.ones(1, dtype = np.float32)).cast(x.dtype) + ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: x = x * (ones + scale_msa) + shift_msa else: x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1) - # x = x * unsqueeze((ones + scale_msa), 1) + unsqueeze(shift_msa, 1) return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + class AdaLayerNormZero_Final(Module): def __init__(self, dim): super().__init__() @@ -127,9 +76,7 @@ class AdaLayerNormZero_Final(Module): def forward(self, x, emb): emb = self.linear(silu(emb)) scale, shift = chunk(emb, 2, dim=1) - # scale ----> (1, 1024) - # x ----> (1, -1, 1024) - ones = constant(np.ones(1, dtype = np.float32)).cast(x.dtype) + ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: x = self.norm(x) * (ones + scale) + shift else: @@ -137,6 +84,7 @@ class AdaLayerNormZero_Final(Module): x = x + unsqueeze(shift, 1) return x + class ConvPositionEmbedding(Module): def __init__(self, dim, kernel_size=31, groups=16): super().__init__() @@ -145,10 +93,7 @@ class ConvPositionEmbedding(Module): self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2) self.mish = Mish() - def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 - # if mask is not None: - # mask = mask[..., None] - # x = x.masked_fill(~mask, 0.0) + def forward(self, x, mask): # noqa: F722 if default_net().plugin_config.remove_input_padding: x = unsqueeze(x, 0) x = permute(x, [0, 2, 1]) @@ -156,10 +101,9 @@ class ConvPositionEmbedding(Module): out = permute(x, [0, 2, 1]) if default_net().plugin_config.remove_input_padding: out = squeeze(out, 0) - # if mask is not None: - # out = out.masked_fill(~mask, 0.0) return out + class Attention(Module): def __init__( self, @@ -168,8 +112,8 @@ class Attention(Module): heads: int = 16, dim_head: int = 64, dropout: float = 0.0, - context_dim: Optional[int] = None, # if not None -> joint attention - context_pre_only = None, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, ): super().__init__() @@ -178,7 +122,7 @@ class Attention(Module): self.processor = processor - self.dim = dim # hidden_size + self.dim = dim # hidden_size self.heads = heads 
self.inner_dim = dim_head * heads self.dropout = dropout @@ -187,28 +131,33 @@ class Attention(Module): self.context_pre_only = context_pre_only self.tp_size = 1 self.num_attention_heads = heads // self.tp_size - self.num_attention_kv_heads = heads // self.tp_size # 8 - self.dtype = str_dtype_to_trt('float32') + self.num_attention_kv_heads = heads // self.tp_size # 8 + self.dtype = str_dtype_to_trt("float32") self.attention_hidden_size = self.attention_head_size * self.num_attention_heads - # self.to_q = Linear(dim, self.inner_dim) - self.to_q = ColumnLinear(dim, - self.tp_size * self.num_attention_heads *self.attention_head_size, - bias=True, - dtype=self.dtype, - tp_group=None, - tp_size=self.tp_size) - self.to_k = ColumnLinear(dim, - self.tp_size * self.num_attention_heads *self.attention_head_size, - bias=True, - dtype=self.dtype, - tp_group=None, - tp_size=self.tp_size) - self.to_v = ColumnLinear(dim, - self.tp_size * self.num_attention_heads *self.attention_head_size, - bias=True, - dtype=self.dtype, - tp_group=None, - tp_size=self.tp_size) + self.to_q = ColumnLinear( + dim, + self.tp_size * self.num_attention_heads * self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) + self.to_k = ColumnLinear( + dim, + self.tp_size * self.num_attention_heads * self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) + self.to_v = ColumnLinear( + dim, + self.tp_size * self.num_attention_heads * self.attention_head_size, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) if self.context_dim is not None: self.to_k_c = Linear(context_dim, self.inner_dim) @@ -216,112 +165,100 @@ class Attention(Module): if self.context_pre_only is not None: self.to_q_c = Linear(context_dim, self.inner_dim) - # self.to_out = Linear(self.inner_dim, dim) - self.to_out = RowLinear(self.tp_size * self.num_attention_heads * - self.attention_head_size, - dim, - bias=True, - dtype=self.dtype, - tp_group=None, - tp_size=self.tp_size) - # self.to_out.append(Dropout(dropout)) + self.to_out = RowLinear( + self.tp_size * self.num_attention_heads * self.attention_head_size, + dim, + bias=True, + dtype=self.dtype, + tp_group=None, + tp_size=self.tp_size, + ) if self.context_pre_only is not None and not self.context_pre_only: self.to_out_c = Linear(self.inner_dim, dim) - def forward( - self, - x: float['b n d'], # noised input x - rope_cos, - rope_sin, - input_lengths, - c: float['b n d'] = None, # context c - scale = 1.0, - rope=None, - c_rope=None, # rotary position embedding for c + self, + x, # noised input x + rope_cos, + rope_sin, + input_lengths, + c=None, # context c + scale=1.0, + rope=None, + c_rope=None, # rotary position embedding for c ) -> torch.Tensor: if c is not None: return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope) else: - return self.processor(self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale) + return self.processor( + self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale + ) + def rotate_every_two_3dim(tensor: Tensor) -> Tensor: - shape_tensor = concat([ - shape(tensor, i) / 2 if i == (tensor.ndim() - - 1) else shape(tensor, i) - for i in range(tensor.ndim()) - ]) + shape_tensor = concat( + [shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())] + ) if default_net().plugin_config.remove_input_padding: assert 
tensor.ndim() == 2 x1 = slice(tensor, [0, 0], shape_tensor, [1, 2]) x2 = slice(tensor, [0, 1], shape_tensor, [1, 2]) x1 = expand_dims(x1, 2) x2 = expand_dims(x2, 2) - zero = constant( - np.ascontiguousarray( - np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) x2 = zero - x2 x = concat([x2, x1], 2) - out = view( - x, concat([shape(x, 0), - shape(x, 1)*2])) + out = view(x, concat([shape(x, 0), shape(x, 1) * 2])) else: assert tensor.ndim() == 3 - x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2]) - x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2]) + x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2]) + x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2]) x1 = expand_dims(x1, 3) x2 = expand_dims(x2, 3) - zero = constant( - np.ascontiguousarray( - np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) + zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype)))) x2 = zero - x2 x = concat([x2, x1], 3) - out = view( - x, concat([shape(x, 0), - shape(x, 1), - shape(x, 2)*2])) + out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2])) return out + def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin): if default_net().plugin_config.remove_input_padding: - rot_dim = shape(rope_cos, -1) #64 - new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64) + rot_dim = shape(rope_cos, -1) # 64 + new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64) x_ = slice(x, [0, 0], new_t_shape, [1, 1]) end_dim = shape(x, -1) - shape(rope_cos, -1) - new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960) + new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960) x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1]) - out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim = -1) + out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1) else: - rot_dim = shape(rope_cos, 2) #64 - new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64) + rot_dim = shape(rope_cos, 2) # 64 + new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64) x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1]) end_dim = shape(x, 2) - shape(rope_cos, 2) - new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960) + new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960) x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1]) - out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim = -1) - # t -> (2,-1,1024) freqs -> (-1,64) + out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1) return out + class AttnProcessor: def __init__(self): pass def __call__( self, - attn: Attention, - x: float['b n d'], # noised input x + attn, + x, # noised input x rope_cos, rope_sin, input_lengths, - scale = 1.0, + scale=1.0, rope=None, ) -> torch.FloatTensor: - - - query = attn.to_q(x) key = attn.to_k(x) value = attn.to_v(x) @@ -331,7 +268,6 @@ class AttnProcessor: # attention inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads norm_factor = math.sqrt(attn.attention_head_size) q_scaling = 1.0 / norm_factor mask = None @@ -341,58 +277,50 @@ class AttnProcessor: seq_len_2d = concat([1, N]) max_position_embeddings = 4096 # create position ids - position_ids_buffer = constant( - np.expand_dims( - np.arange(max_position_embeddings).astype(np.int32), 
- 0)) - tmp_position_ids = slice(position_ids_buffer, - starts=[0, 0], - sizes=seq_len_2d) - tmp_position_ids = expand(tmp_position_ids, concat([B, N])) #BxL - tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 - tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) #BxL + position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1 + tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL mask = tmp_position_ids < tmp_input_lengths # BxL - mask = mask.cast('int32') + mask = mask.cast("int32") if default_net().plugin_config.bert_attention_plugin: - qkv = concat([query, key, value], dim = -1) + qkv = concat([query, key, value], dim=-1) # TRT plugin mode assert input_lengths is not None if default_net().plugin_config.remove_input_padding: - qkv = qkv.view( - concat([-1, 3 * inner_dim])) + qkv = qkv.view(concat([-1, 3 * inner_dim])) max_input_length = constant( - np.zeros([ - 2048, - ], dtype=np.int32)) - print("============================================================================") + np.zeros( + [ + 2048, + ], + dtype=np.int32, + ) + ) else: max_input_length = None - print("******************************************************************************************************") - context = bert_attention(qkv, - input_lengths, - attn.num_attention_heads, - attn.attention_head_size, - q_scaling=q_scaling, - max_input_length=max_input_length) + context = bert_attention( + qkv, + input_lengths, + attn.num_attention_heads, + attn.attention_head_size, + q_scaling=q_scaling, + max_input_length=max_input_length, + ) else: assert not default_net().plugin_config.remove_input_padding def transpose_for_scores(x): - new_x_shape = concat([ - shape(x, 0), - shape(x, 1), attn.num_attention_heads, attn.attention_head_size - ]) + new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) y = x.view(new_x_shape) y = y.transpose(1, 2) return y def transpose_for_scores_k(x): - new_x_shape = concat([ - shape(x, 0), - shape(x, 1), attn.num_attention_heads, attn.attention_head_size - ]) + new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size]) y = x.view(new_x_shape) y = y.permute([0, 2, 3, 1]) @@ -408,43 +336,40 @@ class AttnProcessor: attention_mask = expand_mask(mask, shape(query, 2)) attention_mask = cast(attention_mask, attention_scores.dtype) attention_scores = attention_scores + attention_mask - + attention_probs = softmax(attention_scores, dim=-1) context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2) - context = context.view( - concat([ - shape(context, 0), - shape(context, 1), attn.attention_hidden_size - ])) + context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size])) context = attn.to_out(context) if mask is not None: mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1])) mask = expand_dims_like(mask, context) mask = cast(mask, context.dtype) - # mask = where(mask ==0, 0.0, 1.0) context = context * mask return context + # DiT Block class DiTBlock(Module): - - def __init__(self, dim, heads, dim_head, ff_mult = 2, dropout = 0.1): + def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1): super().__init__() self.attn_norm = AdaLayerNormZero(dim) self.attn = Attention( - 
processor = AttnProcessor(), - dim = dim, - heads = heads, - dim_head = dim_head, - dropout = dropout, - ) + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout) + self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout) - def forward(self, x, t, rope_cos, rope_sin, input_lengths, scale = 1.0, rope = ModuleNotFoundError): # x: noised input, t: time embedding + def forward( + self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError + ): # x: noised input, t: time embedding # pre-norm & modulation for attention input norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) # attention @@ -456,7 +381,7 @@ class DiTBlock(Module): x = x + gate_msa * attn_output else: x = x + unsqueeze(gate_msa, 1) * attn_output - ones = constant(np.ones(1, dtype = np.float32)).cast(x.dtype) + ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype) if default_net().plugin_config.remove_input_padding: norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp else: @@ -470,20 +395,6 @@ class DiTBlock(Module): return x -# class SinusPositionEmbedding(Module): -# def __init__(self, dim): -# super().__init__() -# self.dim = dim - -# def forward(self, x, scale=1000): -# half_dim = self.dim // 2 -# emb = math.log(10000) / (half_dim - 1) -# emb = exp(arange(start=0, end=half_dim, dtype=trt_dtype_to_str(trt.float32)) * - emb) -# emb = scale * unsqueeze(x, 1) * unsqueeze(emb, 0) -# emb = concat([cos(emb), sin(emb)], dim=-1) -# emb = emb.cast(x.dtype) -# assert self.dim % 2 == 0 -# return emb class TimestepEmbedding(Module): def __init__(self, dim, freq_embed_dim=256, dtype=None): @@ -492,7 +403,7 @@ class TimestepEmbedding(Module): self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype) self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype) - def forward(self, timestep: float["b n"]): # noqa: F821 + def forward(self, timestep): t_freq = self.mlp1(timestep) t_freq = silu(t_freq) t_emb = self.mlp2(t_freq) diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh index ed38497..bf5526d 100644 --- a/src/f5_tts/runtime/triton_trtllm/run.sh +++ b/src/f5_tts/runtime/triton_trtllm/run.sh @@ -47,7 +47,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then echo "Building triton server" rm -r $model_repo cp -r ./model_repo_f5_tts $model_repo - python3 fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos + python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan fi diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py index 4e0d5d3..563ba84 100644 --- a/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py +++ b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py @@ -41,16 +41,25 @@ import torch.nn.functional as F from scipy.signal import check_COLA, get_window support_clp_op = None -if th.__version__ >= '1.7.0': +if th.__version__ >= "1.7.0": from torch.fft import rfft as fft + support_clp_op = True else: from torch import 
rfft as fft + class STFT(th.nn.Module): - def __init__(self, win_len=1024, win_hop=512, fft_len=1024, - enframe_mode='continue', win_type='hann', - win_sqrt=False, pad_center=True): + def __init__( + self, + win_len=1024, + win_hop=512, + fft_len=1024, + enframe_mode="continue", + win_type="hann", + win_sqrt=False, + pad_center=True, + ): """ Implement of STFT using 1D convolution and 1D transpose convolutions. Implement of framing the signal in 2 ways, `break` and `continue`. @@ -71,7 +80,7 @@ class STFT(th.nn.Module): pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True. """ super(STFT, self).__init__() - assert enframe_mode in ['break', 'continue'] + assert enframe_mode in ["break", "continue"] assert fft_len >= win_len self.win_len = win_len self.win_hop = win_hop @@ -83,21 +92,21 @@ class STFT(th.nn.Module): self.pad_amount = self.fft_len // 2 en_k, fft_k, ifft_k, ola_k = self.__init_kernel__() - self.register_buffer('en_k', en_k) - self.register_buffer('fft_k', fft_k) - self.register_buffer('ifft_k', ifft_k) - self.register_buffer('ola_k', ola_k) + self.register_buffer("en_k", en_k) + self.register_buffer("fft_k", fft_k) + self.register_buffer("ifft_k", ifft_k) + self.register_buffer("ola_k", ola_k) def __init_kernel__(self): """ Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel. ** enframe_kernel: Using conv1d layer and identity matrix. ** fft_kernel: Using linear layer for matrix multiplication. In fact, - enframe_kernel and fft_kernel can be combined, But for the sake of + enframe_kernel and fft_kernel can be combined, But for the sake of readability, I took the two apart. ** ifft_kernel, pinv of fft_kernel. ** overlap-add kernel, just like enframe_kernel, but transposed. - + Returns: tuple: four kernels. 
""" @@ -107,21 +116,17 @@ class STFT(th.nn.Module): fft_kernel = th.stack([tmp.real, tmp.imag], dim=2) else: fft_kernel = fft(th.eye(self.fft_len), 1) - if self.mode == 'break': + if self.mode == "break": enframed_kernel = th.eye(self.win_len)[:, None, :] - fft_kernel = fft_kernel[:self.win_len] - fft_kernel = th.cat( - (fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1) + fft_kernel = fft_kernel[: self.win_len] + fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1) ifft_kernel = th.pinverse(fft_kernel)[:, None, :] window = get_window(self.win_type, self.win_len) - self.perfect_reconstruct = check_COLA( - window, - self.win_len, - self.win_len-self.win_hop) + self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop) window = th.FloatTensor(window) - if self.mode == 'continue': - left_pad = (self.fft_len - self.win_len)//2 + if self.mode == "continue": + left_pad = (self.fft_len - self.win_len) // 2 right_pad = left_pad + (self.fft_len - self.win_len) % 2 window = F.pad(window, (left_pad, right_pad)) if self.win_sqrt: @@ -132,9 +137,9 @@ class STFT(th.nn.Module): fft_kernel = fft_kernel.T * window ifft_kernel = ifft_kernel * window - ola_kernel = th.eye(self.fft_len)[:self.win_len, None, :] - if self.mode == 'continue': - ola_kernel = th.eye(self.fft_len)[:, None, :self.fft_len] + ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :] + if self.mode == "continue": + ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len] return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel def is_perfect(self): @@ -147,7 +152,7 @@ class STFT(th.nn.Module): """ return self.perfect_reconstruct and self.pad_center - def transform(self, inputs, return_type='complex'): + def transform(self, inputs, return_type="complex"): """Take input data (audio) to STFT domain. Args: @@ -158,39 +163,38 @@ class STFT(th.nn.Module): Returns: tuple: (mag, phase) when `magphase`, return (real, imag) when - `realimag`. Defaults to 'complex', each elements with shape + `realimag`. Defaults to 'complex', each elements with shape [num_batch, num_frequencies, num_frames] """ - assert return_type in ['magphase', 'realimag', 'complex'] + assert return_type in ["magphase", "realimag", "complex"] if inputs.dim() == 2: inputs = th.unsqueeze(inputs, 1) self.num_samples = inputs.size(-1) if self.pad_center: - inputs = F.pad( - inputs, (self.pad_amount, self.pad_amount), mode='reflect') + inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect") enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop) outputs = th.transpose(enframe_inputs, 1, 2) outputs = F.linear(outputs, self.fft_k) outputs = th.transpose(outputs, 1, 2) - dim = self.fft_len//2+1 + dim = self.fft_len // 2 + 1 real = outputs[:, :dim, :] imag = outputs[:, dim:, :] - if return_type == 'realimag': + if return_type == "realimag": return real, imag - elif return_type == 'complex': + elif return_type == "complex": assert support_clp_op return th.complex(real, imag) else: - mags = th.sqrt(real**2+imag**2) + mags = th.sqrt(real**2 + imag**2) phase = th.atan2(imag, real) return mags, phase - def inverse(self, input1, input2=None, input_type='magphase'): - """Call the inverse STFT (iSTFT), given tensors produced + def inverse(self, input1, input2=None, input_type="magphase"): + """Call the inverse STFT (iSTFT), given tensors produced by the `transform` function. 
Args: - input1 (tensors): Magnitude/Real-part of STFT with shape + input1 (tensors): Magnitude/Real-part of STFT with shape [num_batch, num_frequencies, num_frames] input2 (tensors): Phase/Imag-part of STFT with shape [num_batch, num_frequencies, num_frames] @@ -201,16 +205,16 @@ class STFT(th.nn.Module): tensors: Reconstructed audio given magnitude and phase. Of shape [num_batch, num_samples] """ - assert input_type in ['magphase', 'realimag'] - if input_type == 'realimag': + assert input_type in ["magphase", "realimag"] + if input_type == "realimag": real, imag = None, None if support_clp_op and th.is_complex(input1): real, imag = input1.real, input1.imag else: real, imag = input1, input2 else: - real = input1*th.cos(input2) - imag = input1*th.sin(input2) + real = input1 * th.cos(input2) + imag = input1 * th.sin(input2) inputs = th.cat([real, imag], dim=1) outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop) t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1)) @@ -221,11 +225,11 @@ class STFT(th.nn.Module): num_samples = num_frames * self.win_hop rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples - + outputs = outputs[..., rm_start:rm_end] coff = coff[..., rm_start:rm_end] coffidx = th.where(coff > 1e-8) - outputs[coffidx] = outputs[coffidx]/(coff[coffidx]) + outputs[coffidx] = outputs[coffidx] / (coff[coffidx]) return outputs.squeeze(dim=1) def forward(self, inputs): @@ -240,4 +244,4 @@ class STFT(th.nn.Module): """ mag, phase = self.transform(inputs) rec_wav = self.inverse(mag, phase) - return rec_wav \ No newline at end of file + return rec_wav diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py b/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py index f86c87f..22dad65 100644 --- a/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py +++ b/src/f5_tts/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -9,231 +9,207 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import safetensors.torch import torch -import tensorrt_llm from tensorrt_llm import str_dtype_to_torch from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.convert_utils import (split, split_matrix_tp, - split_qkv_bias_tp, split_qkv_tp) +from tensorrt_llm.models.convert_utils import split, split_matrix_tp + def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank): split_v = split(v, tensor_parallel, rank, dim=1) return split_v.contiguous() + def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): split_v = split(v, tensor_parallel, rank, dim=0) return split_v.contiguous() + FACEBOOK_DIT_NAME_MAPPING = { - '^time_embed.time_mlp.0.weight$': 'time_embed.mlp1.weight', - '^time_embed.time_mlp.0.bias$': 'time_embed.mlp1.bias', - '^time_embed.time_mlp.2.weight$': 'time_embed.mlp2.weight', - '^time_embed.time_mlp.2.bias$': 'time_embed.mlp2.bias', - '^input_embed.conv_pos_embed.conv1d.0.weight$': 'input_embed.conv_pos_embed.conv1d1.weight', - '^input_embed.conv_pos_embed.conv1d.0.bias$': 'input_embed.conv_pos_embed.conv1d1.bias', - '^input_embed.conv_pos_embed.conv1d.2.weight$': 'input_embed.conv_pos_embed.conv1d2.weight', - '^input_embed.conv_pos_embed.conv1d.2.bias$': 'input_embed.conv_pos_embed.conv1d2.bias', - '^transformer_blocks.0.attn.to_out.0.weight$': 'transformer_blocks.0.attn.to_out.weight', - '^transformer_blocks.0.attn.to_out.0.bias$': 'transformer_blocks.0.attn.to_out.bias', - '^transformer_blocks.1.attn.to_out.0.weight$': 'transformer_blocks.1.attn.to_out.weight', - 
'^transformer_blocks.1.attn.to_out.0.bias$': 'transformer_blocks.1.attn.to_out.bias', - '^transformer_blocks.2.attn.to_out.0.weight$': 'transformer_blocks.2.attn.to_out.weight', - '^transformer_blocks.2.attn.to_out.0.bias$': 'transformer_blocks.2.attn.to_out.bias', - '^transformer_blocks.3.attn.to_out.0.weight$': 'transformer_blocks.3.attn.to_out.weight', - '^transformer_blocks.3.attn.to_out.0.bias$': 'transformer_blocks.3.attn.to_out.bias', - '^transformer_blocks.4.attn.to_out.0.weight$': 'transformer_blocks.4.attn.to_out.weight', - '^transformer_blocks.4.attn.to_out.0.bias$': 'transformer_blocks.4.attn.to_out.bias', - '^transformer_blocks.5.attn.to_out.0.weight$': 'transformer_blocks.5.attn.to_out.weight', - '^transformer_blocks.5.attn.to_out.0.bias$': 'transformer_blocks.5.attn.to_out.bias', - '^transformer_blocks.6.attn.to_out.0.weight$': 'transformer_blocks.6.attn.to_out.weight', - '^transformer_blocks.6.attn.to_out.0.bias$': 'transformer_blocks.6.attn.to_out.bias', - '^transformer_blocks.7.attn.to_out.0.weight$': 'transformer_blocks.7.attn.to_out.weight', - '^transformer_blocks.7.attn.to_out.0.bias$': 'transformer_blocks.7.attn.to_out.bias', - '^transformer_blocks.8.attn.to_out.0.weight$': 'transformer_blocks.8.attn.to_out.weight', - '^transformer_blocks.8.attn.to_out.0.bias$': 'transformer_blocks.8.attn.to_out.bias', - '^transformer_blocks.9.attn.to_out.0.weight$': 'transformer_blocks.9.attn.to_out.weight', - '^transformer_blocks.9.attn.to_out.0.bias$': 'transformer_blocks.9.attn.to_out.bias', - '^transformer_blocks.10.attn.to_out.0.weight$': 'transformer_blocks.10.attn.to_out.weight', - '^transformer_blocks.10.attn.to_out.0.bias$': 'transformer_blocks.10.attn.to_out.bias', - '^transformer_blocks.11.attn.to_out.0.weight$': 'transformer_blocks.11.attn.to_out.weight', - '^transformer_blocks.11.attn.to_out.0.bias$': 'transformer_blocks.11.attn.to_out.bias', - '^transformer_blocks.12.attn.to_out.0.weight$': 'transformer_blocks.12.attn.to_out.weight', - '^transformer_blocks.12.attn.to_out.0.bias$': 'transformer_blocks.12.attn.to_out.bias', - '^transformer_blocks.13.attn.to_out.0.weight$': 'transformer_blocks.13.attn.to_out.weight', - '^transformer_blocks.13.attn.to_out.0.bias$': 'transformer_blocks.13.attn.to_out.bias', - '^transformer_blocks.14.attn.to_out.0.weight$': 'transformer_blocks.14.attn.to_out.weight', - '^transformer_blocks.14.attn.to_out.0.bias$': 'transformer_blocks.14.attn.to_out.bias', - '^transformer_blocks.15.attn.to_out.0.weight$': 'transformer_blocks.15.attn.to_out.weight', - '^transformer_blocks.15.attn.to_out.0.bias$': 'transformer_blocks.15.attn.to_out.bias', - '^transformer_blocks.16.attn.to_out.0.weight$': 'transformer_blocks.16.attn.to_out.weight', - '^transformer_blocks.16.attn.to_out.0.bias$': 'transformer_blocks.16.attn.to_out.bias', - '^transformer_blocks.17.attn.to_out.0.weight$': 'transformer_blocks.17.attn.to_out.weight', - '^transformer_blocks.17.attn.to_out.0.bias$': 'transformer_blocks.17.attn.to_out.bias', - '^transformer_blocks.18.attn.to_out.0.weight$': 'transformer_blocks.18.attn.to_out.weight', - '^transformer_blocks.18.attn.to_out.0.bias$': 'transformer_blocks.18.attn.to_out.bias', - '^transformer_blocks.19.attn.to_out.0.weight$': 'transformer_blocks.19.attn.to_out.weight', - '^transformer_blocks.19.attn.to_out.0.bias$': 'transformer_blocks.19.attn.to_out.bias', - '^transformer_blocks.20.attn.to_out.0.weight$': 'transformer_blocks.20.attn.to_out.weight', - '^transformer_blocks.20.attn.to_out.0.bias$': 'transformer_blocks.20.attn.to_out.bias', - 
'^transformer_blocks.21.attn.to_out.0.weight$': 'transformer_blocks.21.attn.to_out.weight', - '^transformer_blocks.21.attn.to_out.0.bias$': 'transformer_blocks.21.attn.to_out.bias', - '^transformer_blocks.0.ff.ff.0.0.weight$': 'transformer_blocks.0.ff.project_in.weight', - '^transformer_blocks.0.ff.ff.0.0.bias$': 'transformer_blocks.0.ff.project_in.bias', - '^transformer_blocks.0.ff.ff.2.weight$': 'transformer_blocks.0.ff.ff.weight', - '^transformer_blocks.0.ff.ff.2.bias$': 'transformer_blocks.0.ff.ff.bias', - '^transformer_blocks.1.ff.ff.0.0.weight$': 'transformer_blocks.1.ff.project_in.weight', - '^transformer_blocks.1.ff.ff.0.0.bias$': 'transformer_blocks.1.ff.project_in.bias', - '^transformer_blocks.1.ff.ff.2.weight$': 'transformer_blocks.1.ff.ff.weight', - '^transformer_blocks.1.ff.ff.2.bias$': 'transformer_blocks.1.ff.ff.bias', - '^transformer_blocks.2.ff.ff.0.0.weight$': 'transformer_blocks.2.ff.project_in.weight', - '^transformer_blocks.2.ff.ff.0.0.bias$': 'transformer_blocks.2.ff.project_in.bias', - '^transformer_blocks.2.ff.ff.2.weight$': 'transformer_blocks.2.ff.ff.weight', - '^transformer_blocks.2.ff.ff.2.bias$': 'transformer_blocks.2.ff.ff.bias', - '^transformer_blocks.3.ff.ff.0.0.weight$': 'transformer_blocks.3.ff.project_in.weight', - '^transformer_blocks.3.ff.ff.0.0.bias$': 'transformer_blocks.3.ff.project_in.bias', - '^transformer_blocks.3.ff.ff.2.weight$': 'transformer_blocks.3.ff.ff.weight', - '^transformer_blocks.3.ff.ff.2.bias$': 'transformer_blocks.3.ff.ff.bias', - '^transformer_blocks.4.ff.ff.0.0.weight$': 'transformer_blocks.4.ff.project_in.weight', - '^transformer_blocks.4.ff.ff.0.0.bias$': 'transformer_blocks.4.ff.project_in.bias', - '^transformer_blocks.4.ff.ff.2.weight$': 'transformer_blocks.4.ff.ff.weight', - '^transformer_blocks.4.ff.ff.2.bias$': 'transformer_blocks.4.ff.ff.bias', - '^transformer_blocks.5.ff.ff.0.0.weight$': 'transformer_blocks.5.ff.project_in.weight', - '^transformer_blocks.5.ff.ff.0.0.bias$': 'transformer_blocks.5.ff.project_in.bias', - '^transformer_blocks.5.ff.ff.2.weight$': 'transformer_blocks.5.ff.ff.weight', - '^transformer_blocks.5.ff.ff.2.bias$': 'transformer_blocks.5.ff.ff.bias', - '^transformer_blocks.6.ff.ff.0.0.weight$': 'transformer_blocks.6.ff.project_in.weight', - '^transformer_blocks.6.ff.ff.0.0.bias$': 'transformer_blocks.6.ff.project_in.bias', - '^transformer_blocks.6.ff.ff.2.weight$': 'transformer_blocks.6.ff.ff.weight', - '^transformer_blocks.6.ff.ff.2.bias$': 'transformer_blocks.6.ff.ff.bias', - '^transformer_blocks.7.ff.ff.0.0.weight$': 'transformer_blocks.7.ff.project_in.weight', - '^transformer_blocks.7.ff.ff.0.0.bias$': 'transformer_blocks.7.ff.project_in.bias', - '^transformer_blocks.7.ff.ff.2.weight$': 'transformer_blocks.7.ff.ff.weight', - '^transformer_blocks.7.ff.ff.2.bias$': 'transformer_blocks.7.ff.ff.bias', - '^transformer_blocks.8.ff.ff.0.0.weight$': 'transformer_blocks.8.ff.project_in.weight', - '^transformer_blocks.8.ff.ff.0.0.bias$': 'transformer_blocks.8.ff.project_in.bias', - '^transformer_blocks.8.ff.ff.2.weight$': 'transformer_blocks.8.ff.ff.weight', - '^transformer_blocks.8.ff.ff.2.bias$': 'transformer_blocks.8.ff.ff.bias', - '^transformer_blocks.9.ff.ff.0.0.weight$': 'transformer_blocks.9.ff.project_in.weight', - '^transformer_blocks.9.ff.ff.0.0.bias$': 'transformer_blocks.9.ff.project_in.bias', - '^transformer_blocks.9.ff.ff.2.weight$': 'transformer_blocks.9.ff.ff.weight', - '^transformer_blocks.9.ff.ff.2.bias$': 'transformer_blocks.9.ff.ff.bias', - '^transformer_blocks.10.ff.ff.0.0.weight$': 
'transformer_blocks.10.ff.project_in.weight', - '^transformer_blocks.10.ff.ff.0.0.bias$': 'transformer_blocks.10.ff.project_in.bias', - '^transformer_blocks.10.ff.ff.2.weight$': 'transformer_blocks.10.ff.ff.weight', - '^transformer_blocks.10.ff.ff.2.bias$': 'transformer_blocks.10.ff.ff.bias', - '^transformer_blocks.11.ff.ff.0.0.weight$': 'transformer_blocks.11.ff.project_in.weight', - '^transformer_blocks.11.ff.ff.0.0.bias$': 'transformer_blocks.11.ff.project_in.bias', - '^transformer_blocks.11.ff.ff.2.weight$': 'transformer_blocks.11.ff.ff.weight', - '^transformer_blocks.11.ff.ff.2.bias$': 'transformer_blocks.11.ff.ff.bias', - '^transformer_blocks.12.ff.ff.0.0.weight$': 'transformer_blocks.12.ff.project_in.weight', - '^transformer_blocks.12.ff.ff.0.0.bias$': 'transformer_blocks.12.ff.project_in.bias', - '^transformer_blocks.12.ff.ff.2.weight$': 'transformer_blocks.12.ff.ff.weight', - '^transformer_blocks.12.ff.ff.2.bias$': 'transformer_blocks.12.ff.ff.bias', - '^transformer_blocks.13.ff.ff.0.0.weight$': 'transformer_blocks.13.ff.project_in.weight', - '^transformer_blocks.13.ff.ff.0.0.bias$': 'transformer_blocks.13.ff.project_in.bias', - '^transformer_blocks.13.ff.ff.2.weight$': 'transformer_blocks.13.ff.ff.weight', - '^transformer_blocks.13.ff.ff.2.bias$': 'transformer_blocks.13.ff.ff.bias', - '^transformer_blocks.14.ff.ff.0.0.weight$': 'transformer_blocks.14.ff.project_in.weight', - '^transformer_blocks.14.ff.ff.0.0.bias$': 'transformer_blocks.14.ff.project_in.bias', - '^transformer_blocks.14.ff.ff.2.weight$': 'transformer_blocks.14.ff.ff.weight', - '^transformer_blocks.14.ff.ff.2.bias$': 'transformer_blocks.14.ff.ff.bias', - '^transformer_blocks.15.ff.ff.0.0.weight$': 'transformer_blocks.15.ff.project_in.weight', - '^transformer_blocks.15.ff.ff.0.0.bias$': 'transformer_blocks.15.ff.project_in.bias', - '^transformer_blocks.15.ff.ff.2.weight$': 'transformer_blocks.15.ff.ff.weight', - '^transformer_blocks.15.ff.ff.2.bias$': 'transformer_blocks.15.ff.ff.bias', - '^transformer_blocks.16.ff.ff.0.0.weight$': 'transformer_blocks.16.ff.project_in.weight', - '^transformer_blocks.16.ff.ff.0.0.bias$': 'transformer_blocks.16.ff.project_in.bias', - '^transformer_blocks.16.ff.ff.2.weight$': 'transformer_blocks.16.ff.ff.weight', - '^transformer_blocks.16.ff.ff.2.bias$': 'transformer_blocks.16.ff.ff.bias', - '^transformer_blocks.17.ff.ff.0.0.weight$': 'transformer_blocks.17.ff.project_in.weight', - '^transformer_blocks.17.ff.ff.0.0.bias$': 'transformer_blocks.17.ff.project_in.bias', - '^transformer_blocks.17.ff.ff.2.weight$': 'transformer_blocks.17.ff.ff.weight', - '^transformer_blocks.17.ff.ff.2.bias$': 'transformer_blocks.17.ff.ff.bias', - '^transformer_blocks.18.ff.ff.0.0.weight$': 'transformer_blocks.18.ff.project_in.weight', - '^transformer_blocks.18.ff.ff.0.0.bias$': 'transformer_blocks.18.ff.project_in.bias', - '^transformer_blocks.18.ff.ff.2.weight$': 'transformer_blocks.18.ff.ff.weight', - '^transformer_blocks.18.ff.ff.2.bias$': 'transformer_blocks.18.ff.ff.bias', - '^transformer_blocks.19.ff.ff.0.0.weight$': 'transformer_blocks.19.ff.project_in.weight', - '^transformer_blocks.19.ff.ff.0.0.bias$': 'transformer_blocks.19.ff.project_in.bias', - '^transformer_blocks.19.ff.ff.2.weight$': 'transformer_blocks.19.ff.ff.weight', - '^transformer_blocks.19.ff.ff.2.bias$': 'transformer_blocks.19.ff.ff.bias', - '^transformer_blocks.20.ff.ff.0.0.weight$': 'transformer_blocks.20.ff.project_in.weight', - '^transformer_blocks.20.ff.ff.0.0.bias$': 'transformer_blocks.20.ff.project_in.bias', - 
'^transformer_blocks.20.ff.ff.2.weight$': 'transformer_blocks.20.ff.ff.weight', - '^transformer_blocks.20.ff.ff.2.bias$': 'transformer_blocks.20.ff.ff.bias', - '^transformer_blocks.21.ff.ff.0.0.weight$': 'transformer_blocks.21.ff.project_in.weight', - '^transformer_blocks.21.ff.ff.0.0.bias$': 'transformer_blocks.21.ff.project_in.bias', - '^transformer_blocks.21.ff.ff.2.weight$': 'transformer_blocks.21.ff.ff.weight', - '^transformer_blocks.21.ff.ff.2.bias$': 'transformer_blocks.21.ff.ff.bias', + "^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight", + "^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias", + "^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight", + "^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias", + "^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight", + "^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias", + "^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight", + "^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias", + "^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight", + "^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias", + "^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight", + "^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias", + "^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight", + "^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias", + "^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight", + "^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias", + "^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight", + "^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias", + "^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight", + "^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias", + "^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight", + "^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias", + "^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight", + "^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias", + "^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight", + "^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias", + "^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight", + "^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias", + "^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight", + "^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias", + "^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight", + "^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias", + "^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight", + "^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias", + "^transformer_blocks.13.attn.to_out.0.weight$": 
"transformer_blocks.13.attn.to_out.weight", + "^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias", + "^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight", + "^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias", + "^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight", + "^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias", + "^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight", + "^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias", + "^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight", + "^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias", + "^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight", + "^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias", + "^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight", + "^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias", + "^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight", + "^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias", + "^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight", + "^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias", + "^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight", + "^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias", + "^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight", + "^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias", + "^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight", + "^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias", + "^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight", + "^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias", + "^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight", + "^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias", + "^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight", + "^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias", + "^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight", + "^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias", + "^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight", + "^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias", + "^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight", + "^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias", + "^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight", + "^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias", + "^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight", + "^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias", + "^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight", + 
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias", + "^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight", + "^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias", + "^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight", + "^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias", + "^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight", + "^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias", + "^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight", + "^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias", + "^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight", + "^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias", + "^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight", + "^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias", + "^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight", + "^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias", + "^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight", + "^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias", + "^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight", + "^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias", + "^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight", + "^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias", + "^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight", + "^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias", + "^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight", + "^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias", + "^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight", + "^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias", + "^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight", + "^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias", + "^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight", + "^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias", + "^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight", + "^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias", + "^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight", + "^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias", + "^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight", + "^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias", + "^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight", + "^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias", + "^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight", + "^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias", + "^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight", + 
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias", + "^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight", + "^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias", + "^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight", + "^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias", + "^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight", + "^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias", + "^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight", + "^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias", + "^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight", + "^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias", + "^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight", + "^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias", + "^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight", + "^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias", + "^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight", + "^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias", + "^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight", + "^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias", + "^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight", + "^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias", + "^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight", + "^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias", } def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--model_name', - type=str, - default="F5TTS_Base", - choices=[ - "F5TTS_Base", - ]) # TODO: support F5TTS_v1_Base - parser.add_argument('--timm_ckpt', - type=str, - default="./ckpts/model_1200000.pt") - parser.add_argument('--output_dir', - type=str, - default='./tllm_checkpoint', - help='The path to save the TensorRT-LLM checkpoint') - parser.add_argument('--hidden_size', - type=int, - default=1024, - help='The hidden size of DiT') - parser.add_argument('--depth', - type=int, - default=22, - help='The number of DiTBlock layers') - parser.add_argument('--num_heads', - type=int, - default=16, - help='The number of heads of attention module') - parser.add_argument('--cfg_scale', type=float, default=4.0) - parser.add_argument('--tp_size', - type=int, - default=1, - help='N-way tensor parallelism size') - parser.add_argument('--cp_size', - type=int, - default=1, - help='Context parallelism size') - parser.add_argument('--pp_size', - type=int, - default=1, - help='N-way pipeline parallelism size') - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float32', 'bfloat16', 'float16']) - parser.add_argument('--fp8_linear', - action='store_true', - help='Whether use FP8 for linear layers') parser.add_argument( - '--workers', - type=int, - default=1, - help='The number of workers for converting checkpoint in parallel') + "--model_name", + type=str, + default="F5TTS_Base", + choices=[ + "F5TTS_Base", + ], + ) # TODO: support F5TTS_v1_Base + parser.add_argument("--timm_ckpt", type=str, 
default="./ckpts/model_1200000.pt") + parser.add_argument( + "--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint" + ) + parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT") + parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers") + parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module") + parser.add_argument("--cfg_scale", type=float, default=4.0) + parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size") + parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size") + parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size") + parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"]) + parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers") + parser.add_argument( + "--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel" + ) args = parser.parse_args() return args -def convert_timm_dit(args, mapping, dtype='float32'): - +def convert_timm_dit(args, mapping, dtype="float32"): weights = {} tik = time.time() torch_dtype = str_dtype_to_torch(dtype) tensor_parallel = mapping.tp_size model_params = dict(torch.load(args.timm_ckpt)) - model_params = {k: v for k, v in model_params['ema_model_state_dict'].items() if k.startswith('ema_model.transformer')} - prefix = 'ema_model.transformer.' - model_params = {key[len(prefix):] if key.startswith(prefix) else key: value for key, value in model_params.items()} + model_params = { + k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer") + } + prefix = "ema_model.transformer." + model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()} timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING @@ -248,8 +224,7 @@ def convert_timm_dit(args, mapping, dtype='float32'): weights = dict() for name, param in model_params.items(): - if name == 'input_embed.conv_pos_embed.conv1d.0.weight' or \ - name == 'input_embed.conv_pos_embed.conv1d.2.weight': + if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight": weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1) else: weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype) @@ -257,48 +232,40 @@ def convert_timm_dit(args, mapping, dtype='float32'): assert len(weights) == len(model_params) # new_prefix = 'f5_transformer.' 
- new_prefix = '' - weights = {new_prefix+key:value for key, value in weights.items()} + new_prefix = "" + weights = {new_prefix + key: value for key, value in weights.items()} import math + scale_factor = math.pow(64, -0.25) for k, v in weights.items(): - if re.match('^transformer_blocks.*.attn.to_k.weight$', k): + if re.match("^transformer_blocks.*.attn.to_k.weight$", k): weights[k] *= scale_factor - weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, - tensor_parallel, mapping.tp_rank) + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) - elif re.match('^transformer_blocks.*.attn.to_k.bias$', k): + elif re.match("^transformer_blocks.*.attn.to_k.bias$", k): weights[k] *= scale_factor - weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, - tensor_parallel, mapping.tp_rank) + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) - elif re.match('^transformer_blocks.*.attn.to_q.weight$', k): - weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, - tensor_parallel, mapping.tp_rank) + elif re.match("^transformer_blocks.*.attn.to_q.weight$", k): + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) weights[k] *= scale_factor - elif re.match('^transformer_blocks.*.attn.to_q.bias$', k): - weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, - tensor_parallel, mapping.tp_rank) + elif re.match("^transformer_blocks.*.attn.to_q.bias$", k): + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) weights[k] *= scale_factor - elif re.match('^transformer_blocks.*.attn.to_v.weight$', k): - weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, - tensor_parallel, mapping.tp_rank) + elif re.match("^transformer_blocks.*.attn.to_v.weight$", k): + weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) - elif re.match('^transformer_blocks.*.attn.to_v.bias$', k): - weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, - tensor_parallel, mapping.tp_rank) + elif re.match("^transformer_blocks.*.attn.to_v.bias$", k): + weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank) - elif re.match('^transformer_blocks.*.attn.to_out.weight$', k): - weights[k] = split_matrix_tp(v, - tensor_parallel, - mapping.tp_rank, - dim=1) + elif re.match("^transformer_blocks.*.attn.to_out.weight$", k): + weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1) tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + print(f"Weights loaded. 
Total time: {t}") return weights @@ -306,34 +273,34 @@ def save_config(args): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) config = { - 'architecture': "F5TTS", - 'dtype': args.dtype, - 'hidden_size': 1024, - 'num_hidden_layers': 22, - 'num_attention_heads': 16, - 'dim_head': 64, - 'dropout': 0.1, - 'ff_mult': 2, - 'mel_dim': 100, - 'text_num_embeds': 256, - 'text_dim': 512, - 'conv_layers': 4, - 'long_skip_connection': False, - 'mapping': { - 'world_size': args.cp_size * args.tp_size * args.pp_size, - 'cp_size': args.cp_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - } + "architecture": "F5TTS", + "dtype": args.dtype, + "hidden_size": 1024, + "num_hidden_layers": 22, + "num_attention_heads": 16, + "dim_head": 64, + "dropout": 0.1, + "ff_mult": 2, + "mel_dim": 100, + "text_num_embeds": 256, + "text_dim": 512, + "conv_layers": 4, + "long_skip_connection": False, + "mapping": { + "world_size": args.cp_size * args.tp_size * args.pp_size, + "cp_size": args.cp_size, + "tp_size": args.tp_size, + "pp_size": args.pp_size, + }, } if args.fp8_linear: - config['quantization'] = { - 'quant_algo': "FP8", + config["quantization"] = { + "quant_algo": "FP8", # TODO: add support for exclude modules. # 'exclude_modules': "*final_layer*", } - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + with open(os.path.join(args.output_dir, "config.json"), "w") as f: json.dump(config, f, indent=4) @@ -341,16 +308,17 @@ def covert_and_save(args, rank): if rank == 0: save_config(args) - mapping = Mapping(world_size=args.cp_size * args.tp_size * args.pp_size, - rank=rank, - cp_size=args.cp_size, - tp_size=args.tp_size, - pp_size=args.pp_size) + mapping = Mapping( + world_size=args.cp_size * args.tp_size * args.pp_size, + rank=rank, + cp_size=args.cp_size, + tp_size=args.tp_size, + pp_size=args.pp_size, + ) weights = convert_timm_dit(args, mapping, dtype=args.dtype) - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors")) def execute(workers, func, args): @@ -367,9 +335,7 @@ def execute(workers, func, args): except Exception as e: traceback.print_exc() exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." + assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log." 
def main(): @@ -385,9 +351,9 @@ def main(): execute(args.workers, [covert_and_save] * world_size, args) tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Total time of converting checkpoints: {t}') + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + print(f"Total time of converting checkpoints: {t}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py index 02fcf74..d94f0d7 100644 --- a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py +++ b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py @@ -19,12 +19,12 @@ from huggingface_hub import hf_hub_download from conv_stft import STFT from vocos import Vocos import argparse + opset_version = 17 + def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--vocoder", type=str, @@ -40,8 +40,8 @@ def get_args(): ) return parser.parse_args() + class ISTFTHead(nn.Module): - def __init__(self, n_fft: int, hop_length: int): super().__init__() self.out = None @@ -54,9 +54,10 @@ class ISTFTHead(nn.Module): mag = torch.clip(mag, max=1e2) real = mag * torch.cos(p) imag = mag * torch.sin(p) - audio = self.stft.inverse(input1=real, input2=imag, input_type='realimag') + audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag") return audio + class VocosVocoder(nn.Module): def __init__(self, vocos_vocoder): super(VocosVocoder, self).__init__() @@ -67,31 +68,30 @@ class VocosVocoder(nn.Module): istft_head_for_export = ISTFTHead(n_fft, hop_length) istft_head_for_export.out = istft_head_out self.vocos_vocoder.head = istft_head_for_export - + def forward(self, mel): waveform = self.vocos_vocoder.decode(mel) return waveform + def export_VocosVocoder(vocos_vocoder, output_path, verbose): vocos_vocoder = VocosVocoder(vocos_vocoder).cuda() vocos_vocoder.eval() - + dummy_batch_size = 8 dummy_input_length = 500 - + dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda() with torch.no_grad(): - dummy_waveform = vocos_vocoder( - mel=dummy_mel - ) + dummy_waveform = vocos_vocoder(mel=dummy_mel) print(dummy_waveform.shape) - dummy_input = (dummy_mel) + dummy_input = dummy_mel torch.onnx.export( - vocos_vocoder, - dummy_input, + vocos_vocoder, + dummy_input, output_path, opset_version=opset_version, do_constant_folding=True, @@ -101,10 +101,12 @@ def export_VocosVocoder(vocos_vocoder, output_path, verbose): "mel": {0: "batch_size", 2: "input_length"}, "waveform": {0: "batch_size", 1: "output_length"}, }, - verbose=verbose) - + verbose=verbose, + ) + print("Exported to {}".format(output_path)) + def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None): if vocoder_name == "vocos": # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device) @@ -122,23 +124,14 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cp vocoder.load_state_dict(state_dict) vocoder = vocoder.eval().to(device) elif vocoder_name == "bigvgan": - try: - from third_party.BigVGAN import bigvgan - except ImportError: - print("You need to follow the README to init submodule and change the BigVGAN source code.") - if is_local: - """download from 
-            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
-        else:
-            local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
-            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
-
+        raise NotImplementedError("BigVGAN is not supported yet")
         vocoder.remove_weight_norm()
         vocoder = vocoder.eval().to(device)
     return vocoder
 
+
 if __name__ == "__main__":
     args = get_args()
     vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
     if args.vocoder == "vocos":
-        export_VocosVocoder(vocoder, args.output_path, verbose=False)
\ No newline at end of file
+        export_VocosVocoder(vocoder, args.output_path, verbose=False)
diff --git a/src/f5_tts/runtime/triton_trtllm/fill_template.py b/src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
similarity index 65%
rename from src/f5_tts/runtime/triton_trtllm/fill_template.py
rename to src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
index 584a9f4..105cfac 100644
--- a/src/f5_tts/runtime/triton_trtllm/fill_template.py
+++ b/src/f5_tts/runtime/triton_trtllm/scripts/fill_template.py
@@ -27,16 +27,10 @@ if __name__ == "__main__":
     parser.add_argument("file_path", help="path of the .pbtxt to modify")
     parser.add_argument(
         "substitutions",
-        help=
-        "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
+        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
     )
-    parser.add_argument("--in_place",
-                        "-i",
-                        action="store_true",
-                        help="do the operation in-place")
-    parser.add_argument("--participant_ids",
-                        help="Participant IDs for the model",
-                        default="")
+    parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
+    parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
 
     args = parser.parse_args()
     main(**vars(args))

From 4681a1c177c1a64f203c77011da64574d7ef1bc7 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 3 Apr 2025 02:35:26 +0000
Subject: [PATCH 3/7] remove annotation

---
 src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
index e0b830b..b89ca5c 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
@@ -1,12 +1,3 @@
-"""
-ein notation:
-b - batch
-n - sequence
-nt - text sequence
-nw - raw wave length
-d - dimension
-"""
-
 from __future__ import annotations
 import sys
 import os

From ae51cc3d34aaea3666d7ac24eca4fbc6d9c0db84 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 3 Apr 2025 04:25:43 +0000
Subject: [PATCH 4/7] fix bug

---
 src/f5_tts/runtime/triton_trtllm/README.md       |  6 +++---
 src/f5_tts/runtime/triton_trtllm/client_grpc.py  | 10 +++++++---
 .../model_repo_f5_tts/f5_tts/1/model.py          |  2 +-
 .../model_repo_f5_tts/f5_tts/config.pbtxt        |  4 ++--
 .../runtime/triton_trtllm/patch/f5tts/model.py   |  8 +++-----
 .../runtime/triton_trtllm/patch/f5tts/modules.py |  2 +-
 src/f5_tts/runtime/triton_trtllm/run.sh          | 16 +++++++++-------
 7 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md
index ca69e5f..777f6a8 100644
--- a/src/f5_tts/runtime/triton_trtllm/README.md
+++ b/src/f5_tts/runtime/triton_trtllm/README.md
@@ -35,9 +35,9 @@ python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_t
 ### Benchmark Results
 Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
 
-| Model | Note | Concurrency | Avg Latency | RTF |
-|-------|-----------|-----------------------|---------|--|
-| F5-TTS Base (Vocos) | [Code Commit](https://github.com/yuekaizhang/sherpa/tree/329ab3c573252e835844bea38505c6b43e994cf4/triton/f5_tts) | 1 | 253 ms | 0.0394|
+| Model | Concurrency | Avg Latency | RTF |
+|-------|-------------|-----------------|--|
+| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
 
 ### Credits
 1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
\ No newline at end of file
diff --git a/src/f5_tts/runtime/triton_trtllm/client_grpc.py b/src/f5_tts/runtime/triton_trtllm/client_grpc.py
index 0d3b154..c4e2c43 100644
--- a/src/f5_tts/runtime/triton_trtllm/client_grpc.py
+++ b/src/f5_tts/runtime/triton_trtllm/client_grpc.py
@@ -245,6 +245,7 @@ async def send(
     model_name: str,
     padding_duration: int = None,
     audio_save_dir: str = "./",
+    save_sample_rate: int = 16000,
 ):
     total_duration = 0.0
     latency_data = []
@@ -267,7 +268,9 @@ async def send(
             samples = np.zeros(
                 (
                     1,
-                    padding_duration * sample_rate * ((int(duration) // padding_duration) + 1),
+                    padding_duration
+                    * sample_rate
+                    * ((int(estimated_target_duration + duration) // padding_duration) + 1),
                 ),
                 dtype=np.float32,
             )
@@ -306,7 +309,7 @@ async def send(
         end = time.time() - start
 
         audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
-        sf.write(audio_save_path, audio, 16000, "PCM_16")
+        sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
 
         latency_data.append((end, estimated_target_duration))
         total_duration += estimated_target_duration
@@ -413,7 +416,8 @@ async def main():
                 log_interval=args.log_interval,
                 model_name=args.model_name,
                 audio_save_dir=args.log_dir,
-                padding_duration=1.0,
+                padding_duration=1,
+                save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
             )
         )
         tasks.append(task)
diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
index a0ca9d3..9265886 100644
--- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
+++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py
@@ -158,7 +158,7 @@ class TritonPythonModel:
         return mel.transpose(1, 2)
 
     def forward_vocoder(self, mel):
-        mel = mel.to(torch.float32).contiguous()
+        mel = mel.to(torch.float32).contiguous().cpu()
         input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
 
         inference_request = pb_utils.InferenceRequest(
diff --git a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
index 171211e..4663f7c 100644
--- a/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
+++ b/src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/config.pbtxt
@@ -14,9 +14,9 @@
 
 name: "f5_tts"
 backend: "python"
-max_batch_size: 1
+max_batch_size: 4
 dynamic_batching {
-    max_queue_delay_microseconds: 1
+    max_queue_delay_microseconds: 1000
 }
 parameters [
   {
diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
index b89ca5c..26c8bc9 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
@@ -30,8 +30,7 @@ class InputEmbedding(Module):
         self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
         self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
 
-    def forward(self, x, cond, drop_audio_cond=False):
-        # if drop_audio_cond:  # cfg for cond audio
+    def forward(self, x, cond):
         x = self.proj(concat([x, cond], dim=-1))
         return self.conv_pos_embed(x) + x
@@ -41,9 +40,8 @@ class F5TTS(PretrainedModel):
         super().__init__(config)
         self.dtype = str_dtype_to_trt(config.dtype)
 
-        self.time_embed = TimestepEmbedding(config.hidden_size)  # √
-        text_dim = config.mel_dim if config.text_dim is None else config.text_dim
-        self.input_embed = InputEmbedding(config.mel_dim, text_dim, config.hidden_size)
+        self.time_embed = TimestepEmbedding(config.hidden_size)
+        self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
 
         self.dim = config.hidden_size
         self.depth = config.num_hidden_layers
diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
index 5bfd5a0..a0051b4 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
@@ -93,7 +93,7 @@ class ConvPositionEmbedding(Module):
         self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
         self.mish = Mish()
 
-    def forward(self, x, mask):  # noqa: F722
+    def forward(self, x, mask=None):  # noqa: F722
         if default_net().plugin_config.remove_input_padding:
             x = unsqueeze(x, 0)
         x = permute(x, [0, 2, 1])
diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh
index bf5526d..270c4f5 100644
--- a/src/f5_tts/runtime/triton_trtllm/run.sh
+++ b/src/f5_tts/runtime/triton_trtllm/run.sh
@@ -14,23 +14,22 @@ F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
 F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
 F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
 
-num_task=2
-log_dir=./log_concurrent_tasks_${num_task}
 vocoder_trt_engine_path=vocos_vocoder.plan
 model_repo=./model_repo
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-    echo "Copying f5 tts trtllm files"
-    python_package_path=/usr/local/lib/python3.12/dist-packages
-    cp -r patch/* $python_package_path/tensorrt_llm/models
+    echo "Downloading f5 tts from huggingface"
+    huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
+
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-    echo "Downloading f5 tts from huggingface"
-    huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
+    echo "Converting checkpoint"
     python3 ./scripts/convert_checkpoint.py \
         --timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
         --output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
+    python_package_path=/usr/local/lib/python3.12/dist-packages
+    cp -r patch/* $python_package_path/tensorrt_llm/models
     trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
         --max_batch_size 8 \
         --output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
@@ -58,5 +57,8 @@ fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     echo "Testing triton server"
+    num_task=1
+    log_dir=./log_concurrent_tasks_${num_task}
+    rm -r $log_dir
     python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
 fi
\ No newline at end of file

From eca56943ec1c524f4813e2df955479029029a181 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 3 Apr 2025 04:31:33 +0000
Subject: [PATCH 5/7] fix docker compose issue

---
 src/f5_tts/runtime/triton_trtllm/docker-compose.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/f5_tts/runtime/triton_trtllm/docker-compose.yml b/src/f5_tts/runtime/triton_trtllm/docker-compose.yml
index b08bd08..1519591 100644
--- a/src/f5_tts/runtime/triton_trtllm/docker-compose.yml
+++ b/src/f5_tts/runtime/triton_trtllm/docker-compose.yml
@@ -17,4 +17,4 @@ services:
               device_ids: ['0']
               capabilities: [gpu]
     command: >
-      /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"
+      /bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"

From 94018429309cee4e66c799c218b5032ae4e08b42 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 3 Apr 2025 05:14:03 +0000
Subject: [PATCH 6/7] add http client

---
 src/f5_tts/runtime/triton_trtllm/README.md   |   5 +-
 .../runtime/triton_trtllm/client_http.py     | 142 ++++++++++++++++++
 src/f5_tts/runtime/triton_trtllm/run.sh      |   8 +
 3 files changed, 154 insertions(+), 1 deletion(-)
 create mode 100644 src/f5_tts/runtime/triton_trtllm/client_http.py

diff --git a/src/f5_tts/runtime/triton_trtllm/README.md b/src/f5_tts/runtime/triton_trtllm/README.md
index 777f6a8..a25c9be 100644
--- a/src/f5_tts/runtime/triton_trtllm/README.md
+++ b/src/f5_tts/runtime/triton_trtllm/README.md
@@ -25,7 +25,10 @@ Inside docker container, we would follow the official guide of TensorRT-LLM to b
 ```sh
 bash run.sh 0 4 F5TTS_Base
 ```
-
+### HTTP Client
+```sh
+python3 client_http.py
+```
 ### Benchmark using Dataset
 ```sh
 num_task=2
diff --git a/src/f5_tts/runtime/triton_trtllm/client_http.py b/src/f5_tts/runtime/triton_trtllm/client_http.py
new file mode 100644
index 0000000..87b83d5
--- /dev/null
+++ b/src/f5_tts/runtime/triton_trtllm/client_http.py
@@ -0,0 +1,142 @@
+# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import requests
+import soundfile as sf
+import numpy as np
+import argparse
+
+
+def get_args():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument(
+        "--server-url",
+        type=str,
+        default="localhost:8000",
+        help="Address of the server",
+    )
+
+    parser.add_argument(
+        "--reference-audio",
+        type=str,
+        default="../../infer/examples/basic/basic_ref_en.wav",
+        help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
+    )
+
+    parser.add_argument(
+        "--reference-text",
+        type=str,
+        default="Some call me nature, others call me mother nature.",
+        help="",
+    )
+
+    parser.add_argument(
+        "--target-text",
+        type=str,
+        default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
+        help="",
+    )
+
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="f5_tts",
+        choices=["f5_tts", "spark_tts"],
+        help="triton model_repo module name to request",
+    )
+
+    parser.add_argument(
+        "--output-audio",
+        type=str,
+        default="output.wav",
+        help="Path to save the output audio",
+    )
+    return parser.parse_args()
+
+
+def prepare_request(
+    samples,
+    reference_text,
+    target_text,
+    sample_rate=16000,
+    audio_save_dir: str = "./",
+):
+    assert len(samples.shape) == 1, "samples should be 1D"
+    lengths = np.array([[len(samples)]], dtype=np.int32)
+    samples = samples.reshape(1, -1).astype(np.float32)
+
+    data = {
+        "inputs": [
+            {"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
+            {
+                "name": "reference_wav_len",
+                "shape": lengths.shape,
+                "datatype": "INT32",
+                "data": lengths.tolist(),
+            },
+            {"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
+            {"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
+        ]
+    }
+
+    return data
+
+
+def load_audio(wav_path, target_sample_rate=16000):
+    assert target_sample_rate == 16000, "hard coding in server"
+    if isinstance(wav_path, dict):
+        samples = wav_path["array"]
+        sample_rate = wav_path["sampling_rate"]
+    else:
+        samples, sample_rate = sf.read(wav_path)
+    if sample_rate != target_sample_rate:
+        from scipy.signal import resample
+
+        num_samples = int(len(samples) * (target_sample_rate / sample_rate))
+        samples = resample(samples, num_samples)
+    return samples, target_sample_rate
+
+
+if __name__ == "__main__":
+    args = get_args()
+    server_url = args.server_url
+    if not server_url.startswith(("http://", "https://")):
+        server_url = f"http://{server_url}"
+
+    url = f"{server_url}/v2/models/{args.model_name}/infer"
+    samples, sr = load_audio(args.reference_audio)
+    assert sr == 16000, "sample rate hardcoded in server"
+
+    samples = np.array(samples, dtype=np.float32)
+    data = prepare_request(samples, args.reference_text, args.target_text)
+
+    rsp = requests.post(
+        url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
+    )
+    result = rsp.json()
+    audio = result["outputs"][0]["data"]
+    audio = np.array(audio, dtype=np.float32)
+    sf.write(args.output_audio, audio, 24000, "PCM_16")
diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh
index 270c4f5..4cc6420 100644
--- a/src/f5_tts/runtime/triton_trtllm/run.sh
+++ b/src/f5_tts/runtime/triton_trtllm/run.sh
@@ -61,4 +61,12 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     log_dir=./log_concurrent_tasks_${num_task}
     rm -r $log_dir
     python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+    echo "Testing http client"
+    audio=../../infer/examples/basic/basic_ref_en.wav
+    reference_text="Some call me nature, others call me mother nature."
+    target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
+    python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
 fi
\ No newline at end of file

From 2428d01a566bf3fe544c89cd9f3088c4873f8de3 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 3 Apr 2025 05:25:29 +0000
Subject: [PATCH 7/7] remove empty lines

---
 src/f5_tts/runtime/triton_trtllm/Dockerfile.server | 8 +-------
 src/f5_tts/runtime/triton_trtllm/run.sh            | 4 +---
 2 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/src/f5_tts/runtime/triton_trtllm/Dockerfile.server b/src/f5_tts/runtime/triton_trtllm/Dockerfile.server
index b73bfc4..861e266 100644
--- a/src/f5_tts/runtime/triton_trtllm/Dockerfile.server
+++ b/src/f5_tts/runtime/triton_trtllm/Dockerfile.server
@@ -1,9 +1,3 @@
 FROM nvcr.io/nvidia/tritonserver:24.12-py3
 RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
-WORKDIR /workspace
-
-
-
-
-
-
+WORKDIR /workspace
\ No newline at end of file
diff --git a/src/f5_tts/runtime/triton_trtllm/run.sh b/src/f5_tts/runtime/triton_trtllm/run.sh
index 4cc6420..88e0d68 100644
--- a/src/f5_tts/runtime/triton_trtllm/run.sh
+++ b/src/f5_tts/runtime/triton_trtllm/run.sh
@@ -1,5 +1,3 @@
-
-
 stage=$1
 stop_stage=$2
 model=$3 # F5TTS_Base
@@ -69,4 +67,4 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     reference_text="Some call me nature, others call me mother nature."
     target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
     python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
-fi
\ No newline at end of file
+fi
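
For reference, a fuller invocation of the HTTP client added in PATCH 6/7. This is only a usage sketch: every flag below is defined in client_http.py above, and the values mirror that script's own defaults and the examples used in run.sh stage 6 (a Triton server started by run.sh is assumed to be listening on localhost:8000).

```sh
# Send one offline F5-TTS request to the Triton HTTP (KServe v2) endpoint.
# Flags and values are taken from client_http.py / run.sh in this patch series.
python3 client_http.py \
    --server-url localhost:8000 \
    --model-name f5_tts \
    --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
    --reference-text "Some call me nature, others call me mother nature." \
    --target-text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring." \
    --output-audio output.wav
```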