75 Commits
1.0.3 ... 1.1.4

Author SHA1 Message Date
SWivid
ac79d0ec1e v1.1.4 2025-05-05 04:05:25 +08:00
SWivid
dad398c0c1 Bug Fix #1015
Ensure custom config hashable in
2025-05-05 03:55:05 +08:00
SWivid
3d969bf78d minor fix for backward compatibility to gradio multistyle feature 2025-05-05 02:07:19 +08:00
SWivid
7c741c05f9 v1.1.3 better infer_gradio with cherrypick and cache support 2025-05-05 01:42:41 +08:00
SWivid
6d1a1e886a formatting, sorting 2025-05-05 01:41:28 +08:00
SWivid
b4efcd836a Add cache feature. Retrieve previous generated segments, default cache size 100 2025-05-05 01:37:22 +08:00
SWivid
818b868fab Update infer_gradio.py. Enable seed selecting for multistyle generation 2025-05-05 00:58:24 +08:00
SWivid
e6fee5e9ba Update infer_gradio.py
Use gr.Column to ensure backward compatibility

Remove height attr from gr.File to avoid possible malposition across versions
2025-05-04 09:25:41 +08:00
Yushen CHEN
2de214c122 Merge pull request #1014 from fakerybakery/fix-gradio-app-250503
Fix Gradio app
2025-05-04 09:14:32 +08:00
mrfakename
2999f642ce Row -> Column 2025-05-03 17:59:07 -07:00
mrfakename
03cff73343 remove equal_height requirement
Seems to break Gradio demo.
2025-05-03 17:57:41 -07:00
mrfakename
63c513840d fix gradio app 2025-05-03 17:56:21 -07:00
SWivid
3e6b6c0c0c update infer_gradio.py. rename for consistency 2025-05-04 08:04:00 +08:00
SWivid
f00ac4d06b fix infer-gradio chat feature etc. 2025-05-04 08:00:16 +08:00
Yushen CHEN
b0658bfd24 Merge pull request #1013 from petermg/main
Update infer_gradio.py
2025-05-04 03:33:22 +08:00
petermg
0cae51d646 Update infer_gradio.py
Modified formatting
2025-05-03 12:07:58 -07:00
petermg
95976041f2 Update infer_gradio.py
Added "randomize seed" checkmark and option to specify seed showing last seed used and can manually enter the desired seed number.
2025-05-03 11:38:50 -07:00
petermg
ba1bf74215 Update infer_gradio.py
Modified it so that when you upload a text file, the text of that file will show in the text input window. Also made the text file upload window show up BELOW the text input display window.
2025-05-03 11:22:07 -07:00
petermg
536c29ac57 Update infer_gradio.py
Modified the UI to accept txt files as inputs
2025-05-02 12:45:39 -07:00
SWivid
c4c61b0110 v1.1.2 several updates
add data prepare script recipe for emilia-yodas; fix speech_edit.py; fix tensorrt-llm server code-switch
2025-05-02 03:13:33 +08:00
SWivid
5f80fec160 fix speech_edit.py 2025-04-26 02:10:39 +08:00
Yushen CHEN
178cb8afe6 Merge pull request #986 from fakerybakery/emilia-v2
Add processing script for new Emilia dataset format
2025-04-19 14:16:37 +08:00
mrfakename
761c7ed938 Add processing script for new Emilia dataset format 2025-04-18 20:56:31 -07:00
Yushen CHEN
13fd6f8e07 Merge pull request #971 from tbxark-fork/main
chore: Update the model checkpoint path to use the cache path.
2025-04-14 15:54:50 +08:00
tbxark
b2284b6cff chore: Update the model checkpoint path to use the cache path. 2025-04-14 11:28:48 +08:00
SWivid
4b4359bc39 finetune_gradio not to use fp16 by default for mps device 2025-04-03 22:33:21 +08:00
SWivid
fe5c562212 v1.1.1 add benchmark and trtllm offline code 2025-04-03 18:33:48 +08:00
Yushen CHEN
2374f8ec39 Merge pull request #948 from yuekaizhang/trtllm_benchmark
[TRT-LLM] add benchmark code
2025-04-03 18:27:21 +08:00
Yuekai Zhang
f4f10bff6c fix comment 2025-04-03 02:44:59 -07:00
Yuekai Zhang
9771ec6a3a add benchmark code 2025-04-03 02:42:40 -07:00
SWivid
4b3cd13382 Update README.md 2025-04-03 15:04:42 +08:00
SWivid
25b3291715 Update README.md 2025-04-03 14:41:52 +08:00
SWivid
16c480a61d v1.1.0 Support GPU Depolyment with Triton and TensorRT-LLM #944 2025-04-03 14:37:58 +08:00
SWivid
d9dfbe47cc Update README.md 2025-04-03 14:36:22 +08:00
Yushen CHEN
d1f6c95fe8 Merge pull request #944 from yuekaizhang/triton
Support GPU Depolyment Solution with Triton and TensorRT-LLM
2025-04-03 13:42:37 +08:00
root
2428d01a56 remove empty lines 2025-04-03 05:25:29 +00:00
root
9401842930 add http client 2025-04-03 05:14:03 +00:00
root
eca56943ec fix docker compose issue 2025-04-03 04:31:33 +00:00
root
ae51cc3d34 fix bug 2025-04-03 04:25:43 +00:00
root
4681a1c177 remove annotation 2025-04-03 02:35:26 +00:00
root
5b178397e0 remove unused codes 2025-04-03 02:34:28 +00:00
Yuekai Zhang
2724f9f101 add Nvidia Triton TensorRT-LLM solution 2025-04-02 19:04:45 -07:00
SWivid
7258b09529 v1.0.10 support custom chat model 2025-03-31 21:15:26 +08:00
SWivid
784e3862b4 add microsoft/Phi-4-mini-instruct to chat model list #937 2025-03-31 21:14:39 +08:00
SWivid
6f6968b034 formatting 2025-03-31 19:45:38 +08:00
maximechen
9bd2d13be1 Merge branch 'huanglizhuo-feat/support-custom-chat-model' 2025-03-31 19:22:08 +08:00
maximechen
b7c41af9cd reorganize and distinguish behavior from local and space 2025-03-31 19:11:52 +08:00
huanglizhuo
eaa7fd8a01 Reapply pre-commit hooks 2025-03-29 20:58:42 +09:00
Yushen CHEN
f34465d118 v1.0.9 several fixes 2025-03-28 23:12:13 +08:00
lizhuo
393993321d fix: use pydantic<=2.10.6 to address dependency conflict with gradio-app #930 2025-03-28 23:10:41 +08:00
lizhuo
29d3326bed update: JA latest HF path in SHARED.md #928
* fix: update japanese latest hf path
* update the huggingface url
2025-03-28 22:36:17 +08:00
Zhikang Niu
67e43dc0fb Merge pull request #926 from huanglizhuo/fix/shared-file-path
fix the SHARED.md file path
2025-03-28 17:14:54 +08:00
huanglizhuo
8469025b1c fix the shared.md file path 2025-03-28 17:52:08 +09:00
Zhikang Niu
5bd8cd7aed update: better save last & per ckpt logic #924
Co-authored-by: Yushen CHEN <45333109+SWivid@users.noreply.github.com>
2025-03-28 13:53:12 +08:00
SWivid
7236536f9a update utils_infer.py 2025-03-25 17:24:20 +08:00
SWivid
6b7f6eefdc fix typo in trainer.py with 4ae5347282 formatting #909 2025-03-25 16:17:03 +08:00
SWivid
b9156c0ad5 v1.0.8 fix a fatal bug with log_samples since 37eb3b50da 2025-03-25 07:49:19 +08:00
SWivid
3ad3211915 Update F5TTS_Small.yaml 2025-03-25 07:11:35 +08:00
Zhikang Niu
f6726a78cc Update F5TTS_Small.yaml 2025-03-23 22:27:02 +08:00
SWivid
1d0cf2b8ba add device option for infer-cli, patch-1 2025-03-22 17:35:16 +08:00
SWivid
1d82b7928e add device option for infer-cli 2025-03-22 17:30:23 +08:00
SWivid
4ae5347282 pre-commit update and formatting 2025-03-21 23:01:00 +08:00
SWivid
621559cbbe v1.0.7 2025-03-21 14:40:52 +08:00
SWivid
526b09eebd add no_zero_init v1 variant path to SHARED.md 2025-03-21 14:37:14 +08:00
SWivid
9afa80f204 add option in finetune gradio to save non-ema model weight 2025-03-21 13:36:11 +08:00
SWivid
c6b3189bbd v1.0.6 improves docker usage 2025-03-20 22:48:36 +08:00
Yushen CHEN
c87ce39515 Merge pull request #890 from MicahZoltu/patch-1
Improves documentation around docker usage.
2025-03-20 22:45:40 +08:00
Micah Zoltu
10ef27065b Improves documentation around docker usage. 2025-03-20 21:37:48 +08:00
SWivid
f374640f34 Merge branch 'main' of github.com:SWivid/F5-TTS 2025-03-20 13:54:52 +08:00
SWivid
d5f4c88aa4 update issue templates 2025-03-20 13:54:15 +08:00
Yushen CHEN
f968e13b6d Update README.md 2025-03-20 10:15:47 +08:00
SWivid
339b17fed3 update README.md for infer & train 2025-03-20 10:14:22 +08:00
SWivid
79302b694a update README.md for infer & train 2025-03-20 10:03:54 +08:00
SWivid
a1e88c2a9e v1.0.5 update finetune_gradio.py for clearer guidance 2025-03-17 21:50:50 +08:00
SWivid
1ab90505a4 v1.0.4 fix finetune_gradio.py vocab extend with .safetensors ckpt 2025-03-17 16:22:26 +08:00
66 changed files with 4762 additions and 533 deletions
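Two of the headline changes in this range are the `@lru_cache(maxsize=100)` segment cache added to `infer()` in `infer_gradio.py` (commit b4efcd836a) and the follow-up fix ensuring a custom model config stays hashable (commit dad398c0c1, #1015). A minimal sketch of the underlying technique, assuming nothing beyond standard `functools.lru_cache`; the function and argument names below are illustrative, not the project's exact code:

```python
from functools import lru_cache
import json


@lru_cache(maxsize=100)  # memoizes up to 100 previously generated segments
def cached_infer(ref_audio: str, ref_text: str, gen_text: str, model_cfg_json: str, seed: int):
    # lru_cache requires every argument to be hashable, so an unhashable dict config
    # is passed as a canonical JSON string and rebuilt inside the cached call
    model_cfg = json.loads(model_cfg_json)
    # ... actual synthesis would happen here ...
    return f"wav for {gen_text!r} with cfg {model_cfg} (seed={seed})"


cfg = {"dim": 1024, "depth": 22, "heads": 16}  # values echo the DiT config shown later in this diff
key = json.dumps(cfg, sort_keys=True)
first = cached_infer("ref.wav", "hello", "some text", key, 0)
second = cached_infer("ref.wav", "hello", "some text", key, 0)
assert first is second  # the second call is served from the cache
```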

View File

@@ -1,6 +1,6 @@
name: "Bug Report"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- bug
body:
@@ -15,13 +15,13 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
validations:
required: true
- type: textarea
@@ -39,12 +39,12 @@ body:
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen.
placeholder: Describe in detail what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened.
placeholder: Describe in detail what actually happened.
validations:
required: false

View File

@@ -15,7 +15,7 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and found not discussion yet.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:

View File

@@ -1,6 +1,6 @@
name: "Help Wanted"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- help wanted
body:
@@ -15,36 +15,40 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
placeholder: |
e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
If training or finetuning related, provide detailed configuration including GPU info and training setup.
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository and install as pip package.
3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen, e.g. output a generated audio
placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened, failure messages, etc.
placeholder: Describe what actually happened in detail, failure messages, etc.
validations:
required: false

View File

@@ -1,6 +1,6 @@
name: "Question"
description: |
Pure question or inquiry about the project, usage issue goes with "help wanted".
Research question or pure inquiry about the project, usage issue goes with "help wanted".
labels:
- question
body:
@@ -9,13 +9,13 @@ body:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- label: This template is only for question, not feature requests or bug reports.
- label: This template is only for research question, not usage problems, feature requests or bug reports.
required: true
- label: I have thoroughly reviewed the project documentation and read the related paper(s).
required: true
- label: I have searched for existing issues, including closed ones, no similar questions.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:

.gitignore vendored (2 lines changed)
View File

@@ -7,8 +7,6 @@ ckpts/
wandb/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -1,14 +1,17 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.0
rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
name: ruff linter
args: [--fix]
# Run the formatter.
- id: ruff-format
name: ruff formatter
- id: ruff
name: ruff sorter
args: [--select, I, --fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v5.0.0
hooks:
- id: check-yaml

View File

@@ -23,4 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
ENV SHELL=/bin/bash
VOLUME /root/.cache/huggingface/hub/
EXPOSE 7860
WORKDIR /workspace/F5-TTS

View File

@@ -100,13 +100,34 @@ conda activate f5-tts
# Build from Dockerfile
docker build -t f5tts:v1 .
# Or pull from GitHub Container Registry
docker pull ghcr.io/swivid/f5-tts:main
# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main
# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
### Runtime
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
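A rough sanity check on the RTF column above, assuming RTF here means decoding time divided by the duration of the generated audio (the usual definition); the numbers in the snippet are illustrative, not taken from the table:

```python
def real_time_factor(decode_seconds: float, audio_seconds: float) -> float:
    # real-time factor: how long decoding takes relative to the audio it produces
    return decode_seconds / audio_seconds


# e.g. 0.4 s of decoding for 10 s of speech
print(real_time_factor(0.4, 10.0))  # 0.04, i.e. roughly 25x faster than real time
```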
## Inference
- In order to achieve desired performance, take a moment to read [detailed guidance](src/f5_tts/infer).
- By properly searching the keywords of problem encountered, [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very helpful.
### 1. Gradio App
Currently supported features:
@@ -173,11 +194,6 @@ f5-tts_infer-cli -c custom.toml
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
### 3. More instructions
- In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
- The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
## Training
@@ -200,7 +216,7 @@ Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
## Development
Use pre-commit to ensure code quality (will run linters and formatters automatically)
Use pre-commit to ensure code quality (will run linters and formatters automatically):
```bash
pip install pre-commit
@@ -213,7 +229,7 @@ When making a pull request, before each commit, run:
pre-commit run --all-files
```
Note: Some model components have linting exceptions for E722 to accommodate tensor notation
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
## Acknowledgements
@@ -228,6 +244,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
- [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
## Citation
If our work and codebase is useful for you, please cite as:

View File

@@ -1,12 +1,3 @@
The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
```
ckpts/
F5TTS_v1_Base/
model_1250000.safetensors
F5TTS_Base/
model_1200000.safetensors
E2TTS_Base/
model_1200000.safetensors
```
Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.
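A minimal sketch of that automatic download path, using the same `cached_path` / `hf://` scheme the CLI uses later in this changeset; the checkpoint and vocab filenames are the ones listed above:

```python
from cached_path import cached_path

# resolves (and downloads on first use) into ~/.cache/huggingface/hub/ by default
ckpt_file = cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")
vocab_file = cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt")
print(ckpt_file, vocab_file)
```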

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.0.3"
version = "1.1.4"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -26,6 +26,7 @@ dependencies = [
"librosa",
"matplotlib",
"numpy<=1.26.4",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
"safetensors",

View File

@@ -6,5 +6,5 @@ target-version = "py310"
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = true
force-single-line = false
lines-after-imports = 2

View File

@@ -5,18 +5,18 @@ from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
)
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
from f5_tts.model.utils import seed_everything
@@ -33,7 +33,7 @@ class F5TTS:
hf_cache_dir=None,
):
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
@@ -119,7 +119,7 @@ class F5TTS:
seed_everything(seed)
self.seed = seed
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
wav, sr, spec = infer_process(
ref_file,
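A hedged sketch of how the high-level API touched above is typically driven; `model`, `ref_file`, `ref_text`, and `seed` appear in this hunk, while `gen_text` and the three return values mirror the `infer_process` call and are assumed rather than confirmed by the diff:

```python
from f5_tts.api import F5TTS

tts = F5TTS(model="F5TTS_v1_Base")  # loads configs/F5TTS_v1_Base.yaml, resolving the backbone via get_class
wav, sr, spec = tts.infer(
    ref_file="ref.wav",
    ref_text="transcript of the reference audio",
    gen_text="some lowercase text to synthesize.",
    seed=42,  # seed_everything(seed) is called internally and kept on tts.seed
)
print(sr, tts.seed)
```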

View File

@@ -10,7 +10,7 @@ datasets:
num_workers: 16
optim:
epochs: 11
epochs: 11 # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -49,4 +49,4 @@ ckpts:
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -4,6 +4,7 @@
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@@ -1,6 +1,7 @@
import os
import sys
sys.path.append(os.getcwd())
import argparse
@@ -10,6 +11,7 @@ from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm
@@ -19,9 +21,10 @@ from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
@@ -65,7 +68,7 @@ def main():
no_ref_audio = False
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name
@@ -195,7 +198,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
print(f"Done batch inference in {timediff / 60:.2f} minutes.")
if __name__ == "__main__":

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_librispeech_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_seed_tts_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))

View File

@@ -148,9 +148,9 @@ def get_inference_prompt(
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
assert min_tokens <= total_mel_len <= max_tokens, (
f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
utts[bucket_i].append(utt)

View File

@@ -4,16 +4,17 @@ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://h
**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**
Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text, will automatically do chunk generation. Long reference audio will be **clip short to ~15s**.
Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.
To avoid possible inference failures, make sure you have seen through the following instructions.
- Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
- Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
- Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
## Gradio App
@@ -23,7 +24,7 @@ Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
- [Custom inference with more language support](SHARED.md)
The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
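A tiny, hypothetical helper (not part of the repository) illustrating the text-hygiene guidance above: keep ordinary words lowercase, leave letter-by-letter acronyms like K.F.C. alone, and make sure sentence-final punctuation is followed by a space so chunking splits where intended:

```python
import re


def tidy_gen_text(text: str) -> str:
    text = text.strip()
    # ensure a space after sentence-final punctuation (skip single letters, so K.F.C. is untouched)
    text = re.sub(r"([.!?])([A-Za-z]{2,})", r"\1 \2", text)
    # lowercase pure-uppercase words unless they look like letter-by-letter acronyms
    return " ".join(
        w if "." in w else (w.lower() if w.isupper() and len(w) > 1 else w) for w in text.split()
    )


print(tidy_gen_text("HELLO there.How are you?Fine, thanks. I want K.F.C. today."))
# -> "hello there. How are you? Fine, thanks. I want K.F.C. today."
```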

View File

@@ -44,6 +44,7 @@
```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
@@ -136,11 +137,11 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
#### F5-TTS Base @ ja @ Jmica
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_25498980)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
```bash
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
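For completeness, a hedged sketch of loading one of these community checkpoints outside the Gradio app, mirroring the `load_custom()` path shown in the `infer_gradio.py` diff further below (an `hf://` model path, an `hf://` vocab path, and the JSON config row above):

```python
from cached_path import cached_path

from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT

ckpt_path = str(cached_path("hf://Jmica/F5TTS/JA_21999120/model_21999120.pt"))
vocab_path = str(cached_path("hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt"))
model_cfg = dict(  # matches the Config line above
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512,
    text_mask_padding=False, conv_layers=4, pe_attn_head=1,
)
ema_model = load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
```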

View File

@@ -10,24 +10,25 @@ import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
cross_fade_duration,
device,
fix_duration,
infer_process,
load_model,
load_vocoder,
mel_spec_type,
nfe_step,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
speed,
sway_sampling_coef,
target_rms,
)
from f5_tts.model import DiT, UNetT # noqa: F401. used for config
parser = argparse.ArgumentParser(
@@ -162,6 +163,11 @@ parser.add_argument(
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
parser.add_argument(
"--device",
type=str,
help="Specify the device to run on",
)
args = parser.parse_args()
@@ -202,6 +208,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
device = args.device or config.get("device", device)
# patches for pip pkg user
@@ -239,20 +246,23 @@ if vocoder_name == "vocos":
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
vocoder = load_vocoder(
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
# load TTS model
model_cfg = OmegaConf.load(
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
).model
model_cls = globals()[model_cfg.backbone]
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model != "F5TTS_Base":
assert vocoder_name == model_cfg.mel_spec.mel_spec_type
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
# override for previous models
if model == "F5TTS_Base":
@@ -269,7 +279,9 @@ if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
ema_model = load_model(
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
# inference process
@@ -325,6 +337,7 @@ def main():
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
generated_audio_segments.append(audio_segment)
@@ -332,7 +345,7 @@ def main():
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
final_sample_rate,
)
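Before the `infer_gradio.py` diff that follows, a hedged sketch of what its "custom chat model" support amounts to: any Hugging Face causal-LM chat model listed in (or typed into) `chat_model_name_list` is loaded with `transformers` and queried through its chat template. The prompt text below is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

chat_model_name = "Qwen/Qwen2.5-3B-Instruct"  # or "microsoft/Phi-4-mini-instruct"
model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(chat_model_name)

messages = [
    {"role": "system", "content": "You are whoever the user says you are. Keep responses concise."},
    {"role": "user", "content": "Introduce yourself in one sentence."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=128)
reply = tokenizer.batch_decode(generated[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
print(reply)
```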

View File

@@ -1,20 +1,24 @@
# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file
import gc
import json
import re
import tempfile
from collections import OrderedDict
from functools import lru_cache
from importlib.resources import files
import click
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import spaces
@@ -30,15 +34,15 @@ def gpu_decorator(func):
return func
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
)
from f5_tts.model import DiT, UNetT
DEFAULT_TTS_MODEL = "F5-TTS_v1"
@@ -76,6 +80,8 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
vocab_path = str(cached_path(vocab_path))
if model_cfg is None:
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
elif isinstance(model_cfg, str):
model_cfg = json.loads(model_cfg)
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
@@ -88,7 +94,7 @@ chat_tokenizer_state = None
@gpu_decorator
def generate_response(messages, model, tokenizer):
def chat_model_inference(messages, model, tokenizer):
"""Generate response using Qwen"""
text = tokenizer.apply_chat_template(
messages,
@@ -110,6 +116,17 @@ def generate_response(messages, model, tokenizer):
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
@gpu_decorator
def load_text_from_file(file):
if file:
with open(file, "r", encoding="utf-8") as f:
text = f.read().strip()
else:
text = ""
return gr.update(value=text)
@lru_cache(maxsize=100) # NOTE. need to ensure params of infer() hashable
@gpu_decorator
def infer(
ref_audio_orig,
@@ -117,6 +134,7 @@ def infer(
gen_text,
model,
remove_silence,
seed,
cross_fade_duration=0.15,
nfe_step=32,
speed=1,
@@ -126,8 +144,15 @@ def infer(
gr.Warning("Please provide reference audio.")
return gr.update(), gr.update(), ref_text
# Set inference seed
if seed < 0 or seed > 2**31 - 1:
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
seed = np.random.randint(0, 2**31 - 1)
torch.manual_seed(seed)
used_seed = seed
if not gen_text.strip():
gr.Warning("Please enter text to generate.")
gr.Warning("Please enter text to generate or upload a text file.")
return gr.update(), gr.update(), ref_text
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
@@ -140,7 +165,7 @@ def infer(
show_info("Loading E2-TTS model...")
E2TTS_ema_model = load_e2tts()
ema_model = E2TTS_ema_model
elif isinstance(model, list) and model[0] == "Custom":
elif isinstance(model, tuple) and model[0] == "Custom":
assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
global custom_ema_model, pre_custom_path
if pre_custom_path != model[1]:
@@ -175,7 +200,7 @@ def infer(
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)
return (final_sample_rate, final_wave), spectrogram_path, ref_text
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
with gr.Blocks() as app_credits:
@@ -189,19 +214,38 @@ with gr.Blocks() as app_credits:
with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
with gr.Row():
gen_text_input = gr.Textbox(
label="Text to Generate",
lines=10,
max_lines=40,
scale=4,
)
gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
generate_btn = gr.Button("Synthesize", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
ref_text_input = gr.Textbox(
label="Reference Text",
info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
lines=2,
)
remove_silence = gr.Checkbox(
label="Remove Silences",
info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
value=False,
)
with gr.Row():
ref_text_input = gr.Textbox(
label="Reference Text",
info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
lines=2,
scale=4,
)
ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize Seed",
info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
value=True,
scale=3,
)
seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
with gr.Column(scale=4):
remove_silence = gr.Checkbox(
label="Remove Silences",
info="If undesired long silence(s) produced, turn on to automatically detect and crop.",
value=False,
)
speed_slider = gr.Slider(
label="Speed",
minimum=0.3,
@@ -236,21 +280,39 @@ with gr.Blocks() as app_tts:
ref_text_input,
gen_text_input,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
):
audio_out, spectrogram_path, ref_text_out = infer(
if randomize_seed:
seed_input = np.random.randint(0, 2**31 - 1)
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
ref_audio_input,
ref_text_input,
gen_text_input,
tts_model_choice,
remove_silence,
seed=seed_input,
cross_fade_duration=cross_fade_duration_slider,
nfe_step=nfe_slider,
speed=speed_slider,
)
return audio_out, spectrogram_path, ref_text_out
return audio_out, spectrogram_path, ref_text_out, used_seed
gen_text_file.upload(
load_text_from_file,
inputs=[gen_text_file],
outputs=[gen_text_input],
)
ref_text_file.upload(
load_text_from_file,
inputs=[ref_text_file],
outputs=[ref_text_input],
)
generate_btn.click(
basic_tts,
@@ -259,35 +321,46 @@ with gr.Blocks() as app_tts:
ref_text_input,
gen_text_input,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
],
outputs=[audio_output, spectrogram_output, ref_text_input],
outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
)
def parse_speechtypes_text(gen_text):
# Pattern to find {speechtype}
pattern = r"\{(.*?)\}"
# Pattern to find {str} or {"name": str, "seed": int, "speed": float}
pattern = r"(\{.*?\})"
# Split the text by the pattern
tokens = re.split(pattern, gen_text)
segments = []
current_style = "Regular"
current_type_dict = {
"name": "Regular",
"seed": -1,
"speed": 1.0,
}
for i in range(len(tokens)):
if i % 2 == 0:
# This is text
text = tokens[i].strip()
if text:
segments.append({"style": current_style, "text": text})
current_type_dict["text"] = text
segments.append(current_type_dict)
else:
# This is style
style = tokens[i].strip()
current_style = style
# This is type
type_str = tokens[i].strip()
try: # if type dict
current_type_dict = json.loads(type_str)
except json.decoder.JSONDecodeError:
type_str = type_str[1:-1] # remove brace {}
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
return segments
@@ -298,44 +371,55 @@ with gr.Blocks() as app_multistyle:
"""
# Multiple Speech-Type Generation
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
"""
)
with gr.Row():
gr.Markdown(
"""
**Example Input:**
{Regular} Hello, I'd like to order a sandwich please.
{Surprised} What do you mean you're out of bread?
{Sad} I really wanted a sandwich though...
{Angry} You know what, darn you and your little shop!
{Whisper} I'll just go back home and cry now.
{Shouting} Why me?!
**Example Input:** <br>
{Regular} Hello, I'd like to order a sandwich please. <br>
{Surprised} What do you mean you're out of bread? <br>
{Sad} I really wanted a sandwich though... <br>
{Angry} You know what, darn you and your little shop! <br>
{Whisper} I'll just go back home and cry now. <br>
{Shouting} Why me?!
"""
)
gr.Markdown(
"""
**Example Input 2:**
{Speaker1_Happy} Hello, I'd like to order a sandwich please.
{Speaker2_Regular} Sorry, we're out of bread.
{Speaker1_Sad} I really wanted a sandwich though...
{Speaker2_Whisper} I'll give you the last one I was hiding.
**Example Input 2:** <br>
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
"""
)
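# NOTE (added for clarity, not part of the diff): with the JSON-style labels of
# "Example Input 2" above, parse_speechtypes_text() yields one dict per segment, e.g.
# {"name": "Speaker2_Regular", "seed": -1, "speed": 1, "text": "Sorry, we're out of bread."},
# so each segment carries its own seed and speed into infer().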
gr.Markdown(
"Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
)
# Regular speech type (mandatory)
with gr.Row() as regular_row:
with gr.Column():
with gr.Row(variant="compact") as regular_row:
with gr.Column(scale=1, min_width=160):
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
regular_insert = gr.Button("Insert Label", variant="secondary")
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
with gr.Column(scale=3):
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
with gr.Column(scale=3):
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
with gr.Row():
regular_seed_slider = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
)
regular_speed_slider = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
# Regular speech type (max 100)
max_speech_types = 100
@@ -343,25 +427,55 @@ with gr.Blocks() as app_multistyle:
speech_type_names = [regular_name]
speech_type_audios = [regular_audio]
speech_type_ref_texts = [regular_ref_text]
speech_type_ref_text_files = [regular_ref_text_file]
speech_type_seeds = [regular_seed_slider]
speech_type_speeds = [regular_speed_slider]
speech_type_delete_btns = [None]
speech_type_insert_btns = [regular_insert]
# Additional speech types (99 more)
for i in range(max_speech_types - 1):
with gr.Row(visible=False) as row:
with gr.Column():
with gr.Row(variant="compact", visible=False) as row:
with gr.Column(scale=1, min_width=160):
name_input = gr.Textbox(label="Speech Type Name")
delete_btn = gr.Button("Delete Type", variant="secondary")
insert_btn = gr.Button("Insert Label", variant="secondary")
audio_input = gr.Audio(label="Reference Audio", type="filepath")
ref_text_input = gr.Textbox(label="Reference Text", lines=2)
delete_btn = gr.Button("Delete Type", variant="stop")
with gr.Column(scale=3):
audio_input = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column(scale=3):
ref_text_input = gr.Textbox(label="Reference Text", lines=4)
with gr.Row():
seed_input = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
)
speed_input = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
speech_type_rows.append(row)
speech_type_names.append(name_input)
speech_type_audios.append(audio_input)
speech_type_ref_texts.append(ref_text_input)
speech_type_ref_text_files.append(ref_text_file_input)
speech_type_seeds.append(seed_input)
speech_type_speeds.append(speed_input)
speech_type_delete_btns.append(delete_btn)
speech_type_insert_btns.append(insert_btn)
# Global logic for all speech types
for i in range(max_speech_types):
speech_type_audios[i].clear(
lambda: [None, None],
None,
[speech_type_ref_texts[i], speech_type_ref_text_files[i]],
)
speech_type_ref_text_files[i].upload(
load_text_from_file,
inputs=[speech_type_ref_text_files[i]],
outputs=[speech_type_ref_texts[i]],
)
# Button to add speech type
add_speech_type_btn = gr.Button("Add Speech Type")
@@ -383,27 +497,44 @@ with gr.Blocks() as app_multistyle:
# Function to delete a speech type
def delete_speech_type_fn():
return gr.update(visible=False), None, None, None
return gr.update(visible=False), None, None, None, None
# Update delete button clicks
# Update delete button clicks and ref text file changes
for i in range(1, len(speech_type_delete_btns)):
speech_type_delete_btns[i].click(
delete_speech_type_fn,
outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
outputs=[
speech_type_rows[i],
speech_type_names[i],
speech_type_audios[i],
speech_type_ref_texts[i],
speech_type_ref_text_files[i],
],
)
# Text input for the prompt
gen_text_input_multistyle = gr.Textbox(
label="Text to Generate",
lines=10,
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
)
with gr.Row():
gen_text_input_multistyle = gr.Textbox(
label="Text to Generate",
lines=10,
max_lines=40,
scale=4,
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
)
gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
def make_insert_speech_type_fn(index):
def insert_speech_type_fn(current_text, speech_type_name):
def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
current_text = current_text or ""
speech_type_name = speech_type_name or "None"
updated_text = current_text + f"{{{speech_type_name}}} "
if not speech_type_name:
gr.Warning("Please enter speech type name before insert.")
return current_text
speech_type_dict = {
"name": speech_type_name,
"seed": speech_type_seed,
"speed": speech_type_speed,
}
updated_text = current_text + json.dumps(speech_type_dict) + " "
return updated_text
return insert_speech_type_fn
@@ -412,15 +543,24 @@ with gr.Blocks() as app_multistyle:
insert_fn = make_insert_speech_type_fn(i)
insert_btn.click(
insert_fn,
inputs=[gen_text_input_multistyle, speech_type_names[i]],
inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
outputs=gen_text_input_multistyle,
)
with gr.Accordion("Advanced Settings", open=False):
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
value=True,
)
with gr.Accordion("Advanced Settings", open=True):
with gr.Row():
with gr.Column():
show_cherrypick_multistyle = gr.Checkbox(
label="Show Cherry-pick Interface",
info="Turn on to show interface, picking seeds from previous generations.",
value=False,
)
with gr.Column():
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
info="Turn on to automatically detect and crop long silences.",
value=True,
)
# Generate button
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -428,6 +568,30 @@ with gr.Blocks() as app_multistyle:
# Output audio
audio_output_multistyle = gr.Audio(label="Synthesized Audio")
# Used seed gallery
cherrypick_interface_multistyle = gr.Textbox(
label="Cherry-pick Interface",
lines=10,
max_lines=40,
show_copy_button=True,
interactive=False,
visible=False,
)
# Logic control to show/hide the cherrypick interface
show_cherrypick_multistyle.change(
lambda is_visible: gr.update(visible=is_visible),
show_cherrypick_multistyle,
cherrypick_interface_multistyle,
)
# Function to load text to generate from file
gen_text_file_multistyle.upload(
load_text_from_file,
inputs=[gen_text_file_multistyle],
outputs=[gen_text_input_multistyle],
)
@gpu_decorator
def generate_multistyle_speech(
gen_text,
@@ -455,41 +619,60 @@ with gr.Blocks() as app_multistyle:
# For each segment, generate speech
generated_audio_segments = []
current_style = "Regular"
current_type_name = "Regular"
inference_meta_data = ""
for segment in segments:
style = segment["style"]
name = segment["name"]
seed_input = segment["seed"]
speed = segment["speed"]
text = segment["text"]
if style in speech_types:
current_style = style
if name in speech_types:
current_type_name = name
else:
gr.Warning(f"Type {style} is not available, will use Regular as default.")
current_style = "Regular"
gr.Warning(f"Type {name} is not available, will use Regular as default.")
current_type_name = "Regular"
try:
ref_audio = speech_types[current_style]["audio"]
ref_audio = speech_types[current_type_name]["audio"]
except KeyError:
gr.Warning(f"Please provide reference audio for type {current_style}.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
ref_text = speech_types[current_style].get("ref_text", "")
gr.Warning(f"Please provide reference audio for type {current_type_name}.")
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
ref_text = speech_types[current_type_name].get("ref_text", "")
# Generate speech for this segment
audio_out, _, ref_text_out = infer(
ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
) # show_info=print no pull to top when generating
if seed_input == -1:
seed_input = np.random.randint(0, 2**31 - 1)
# Generate or retrieve speech for this segment
audio_out, _, ref_text_out, used_seed = infer(
ref_audio,
ref_text,
text,
tts_model_choice,
remove_silence,
seed=seed_input,
cross_fade_duration=0,
speed=speed,
show_info=print, # no pull to top when generating
)
sr, audio_data = audio_out
generated_audio_segments.append(audio_data)
speech_types[current_style]["ref_text"] = ref_text_out
speech_types[current_type_name]["ref_text"] = ref_text_out
inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
# Concatenate all audio segments
if generated_audio_segments:
final_audio_data = np.concatenate(generated_audio_segments)
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
return (
[(sr, final_audio_data)]
+ [speech_types[name]["ref_text"] for name in speech_types]
+ [inference_meta_data]
)
else:
gr.Warning("No audio generated.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
generate_multistyle_btn.click(
generate_multistyle_speech,
@@ -502,7 +685,7 @@ with gr.Blocks() as app_multistyle:
+ [
remove_silence_multistyle,
],
outputs=[audio_output_multistyle] + speech_type_ref_texts,
outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
)
# Validation function to disable Generate button if speech types are missing
@@ -519,7 +702,7 @@ with gr.Blocks() as app_multistyle:
# Parse the gen_text to get the speech types used
segments = parse_speechtypes_text(gen_text)
speech_types_in_text = set(segment["style"] for segment in segments)
speech_types_in_text = set(segment["name"] for segment in segments)
# Check if all speech types in text are available
missing_speech_types = speech_types_in_text - speech_types_available
@@ -542,43 +725,58 @@ with gr.Blocks() as app_chat:
gr.Markdown(
"""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
2. Load the chat model.
3. Record your message through your microphone.
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
"""
)
if not USING_SPACES:
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
chat_model_name_list = [
"Qwen/Qwen2.5-3B-Instruct",
"microsoft/Phi-4-mini-instruct",
]
chat_interface_container = gr.Column(visible=False)
@gpu_decorator
def load_chat_model(chat_model_name):
show_info = gr.Info
global chat_model_state, chat_tokenizer_state
if chat_model_state is not None:
chat_model_state = None
chat_tokenizer_state = None
gc.collect()
torch.cuda.empty_cache()
@gpu_decorator
def load_chat_model():
global chat_model_state, chat_tokenizer_state
if chat_model_state is None:
show_info = gr.Info
show_info("Loading chat model...")
model_name = "Qwen/Qwen2.5-3B-Instruct"
chat_model_state = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="auto", device_map="auto"
)
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
show_info("Chat model loaded.")
show_info(f"Loading chat model: {chat_model_name}")
chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
show_info(f"Chat model {chat_model_name} loaded successfully!")
return gr.update(visible=False), gr.update(visible=True)
return gr.update(visible=False), gr.update(visible=True)
load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
if USING_SPACES:
load_chat_model(chat_model_name_list[0])
else:
chat_interface_container = gr.Column()
chat_model_name_input = gr.Dropdown(
choices=chat_model_name_list,
value=chat_model_name_list[0],
label="Chat Model Name",
info="Enter the name of a HuggingFace chat model",
allow_custom_value=not USING_SPACES,
)
load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
chat_interface_container = gr.Column(visible=USING_SPACES)
if chat_model_state is None:
model_name = "Qwen/Qwen2.5-3B-Instruct"
chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
chat_model_name_input.change(
lambda: gr.update(visible=True),
None,
load_chat_model_btn,
show_progress="hidden",
)
load_chat_model_btn.click(
load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
)
with chat_interface_container:
with gr.Row():
@@ -586,22 +784,35 @@ Have a conversation with an AI using your reference voice!
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
scale=3,
)
ref_text_file_chat = gr.File(
label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
)
with gr.Row():
randomize_seed_chat = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Uncheck to use the seed specified.",
scale=3,
)
seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
remove_silence_chat = gr.Checkbox(
label="Remove Silences",
value=True,
)
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
)
system_prompt_chat = gr.Textbox(
label="System Prompt",
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)
chatbot_interface = gr.Chatbot(label="Conversation")
chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
with gr.Row():
with gr.Column():
@@ -618,132 +829,101 @@ Have a conversation with an AI using your reference voice!
send_btn_chat = gr.Button("Send Message")
clear_btn_chat = gr.Button("Clear Conversation")
conversation_state = gr.State(
value=[
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
)
# Modify process_audio_input to use model and tokenizer from state
# Modify process_audio_input to generate user input
@gpu_decorator
def process_audio_input(audio_path, text, history, conv_state):
def process_audio_input(conv_state, audio_path, text):
"""Handle audio or text input from user"""
if not audio_path and not text.strip():
return history, conv_state, ""
return conv_state
if audio_path:
text = preprocess_ref_audio_text(audio_path, text)[1]
if not text.strip():
return history, conv_state, ""
return conv_state
conv_state.append({"role": "user", "content": text})
history.append((text, None))
return conv_state
response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
# Use model and tokenizer from state to get text response
@gpu_decorator
def generate_text_response(conv_state, system_prompt):
"""Generate text response from AI"""
system_prompt_state = [{"role": "system", "content": system_prompt}]
response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
conv_state.append({"role": "assistant", "content": response})
history[-1] = (text, response)
return history, conv_state, ""
return conv_state
@gpu_decorator
def generate_audio_response(history, ref_audio, ref_text, remove_silence):
def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
"""Generate TTS audio for AI response"""
if not history or not ref_audio:
return None
if not conv_state or not ref_audio:
return None, ref_text, seed_input
last_user_message, last_ai_response = history[-1]
if not last_ai_response:
return None
last_ai_response = conv_state[-1]["content"]
if not last_ai_response or conv_state[-1]["role"] != "assistant":
return None, ref_text, seed_input
audio_result, _, ref_text_out = infer(
if randomize_seed:
seed_input = np.random.randint(0, 2**31 - 1)
audio_result, _, ref_text_out, used_seed = infer(
ref_audio,
ref_text,
last_ai_response,
tts_model_choice,
remove_silence,
seed=seed_input,
cross_fade_duration=0.15,
speed=1.0,
show_info=print,  # show_info=print so the page does not scroll to top while generating
)
return audio_result, ref_text_out
return audio_result, ref_text_out, used_seed
def clear_conversation():
"""Reset the conversation"""
return [], [
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
return [], None
def update_system_prompt(new_prompt):
"""Update the system prompt and reset the conversation"""
new_conv_state = [{"role": "system", "content": new_prompt}]
return [], new_conv_state
# Handle audio input
audio_input_chat.stop_recording(
process_audio_input,
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
audio_input_chat,
ref_text_file_chat.upload(
load_text_from_file,
inputs=[ref_text_file_chat],
outputs=[ref_text_chat],
)
# Handle text input
text_input_chat.submit(
process_audio_input,
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
text_input_chat,
)
for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
user_operation(
process_audio_input,
inputs=[chatbot_interface, audio_input_chat, text_input_chat],
outputs=[chatbot_interface],
).then(
generate_text_response,
inputs=[chatbot_interface, system_prompt_chat],
outputs=[chatbot_interface],
).then(
generate_audio_response,
inputs=[
chatbot_interface,
ref_audio_chat,
ref_text_chat,
remove_silence_chat,
randomize_seed_chat,
seed_input_chat,
],
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
).then(
lambda: [None, None],
None,
[audio_input_chat, text_input_chat],
)
# Handle send button
send_btn_chat.click(
process_audio_input,
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
text_input_chat,
)
# Handle clear button
clear_btn_chat.click(
clear_conversation,
outputs=[chatbot_interface, conversation_state],
)
# Handle system prompt change and reset conversation
system_prompt_chat.change(
update_system_prompt,
inputs=system_prompt_chat,
outputs=[chatbot_interface, conversation_state],
)
# Handle clear button or system prompt change and reset conversation
for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
user_operation(
clear_conversation,
outputs=[chatbot_interface, audio_output_chat],
)
with gr.Blocks() as app:
@@ -758,9 +938,9 @@ This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not
The checkpoints currently support English and Chinese.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 12s with ✂ in the bottom right corner (otherwise might have non-optimal auto-trimmed result).
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12s). Ensure the audio is fully uploaded before generating.**
"""
)
@@ -781,7 +961,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
global tts_model_choice
if new_choice == "Custom": # override in case webpage is refreshed
custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
return (
gr.update(visible=True, value=custom_ckpt_path),
gr.update(visible=True, value=custom_vocab_path),
@@ -793,7 +973,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
global tts_model_choice
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
with open(last_used_custom, "w", encoding="utf-8") as f:
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")

View File

@@ -1,5 +1,6 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
@@ -7,12 +8,15 @@ from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM, DiT, UNetT # noqa: F401. used for config
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
"cuda"
if torch.cuda.is_available()
@@ -40,7 +44,7 @@ target_rms = 0.1
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = globals()[model_cfg.model.backbone]
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
dataset_name = model_cfg.datasets.name
@@ -54,7 +58,8 @@ win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
output_dir = "tests"
@@ -151,7 +156,7 @@ for part in parts_to_edit:
dim=-1,
)
offset = end * target_sample_rate
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)

View File

@@ -4,6 +4,7 @@ import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
@@ -14,6 +15,7 @@ from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
@@ -21,16 +23,14 @@ import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import snapshot_download, hf_hub_download
from huggingface_hub import hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
from f5_tts.model import CFM
from f5_tts.model.utils import (
get_tokenizer,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}
@@ -128,11 +128,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
except ImportError:
print("You need to follow the README to init submodule and change the BigVGAN source code.")
if is_local:
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
# download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
vocoder = bigvgan.BigVGAN.from_pretrained(
"nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
@@ -149,7 +150,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -186,7 +187,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -289,7 +290,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
@@ -302,7 +303,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 15s, clipping short. (1)")
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
@@ -314,7 +315,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 15s, clipping short. (2)")
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
@@ -323,7 +324,7 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 15s, clipping short. (3)")
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")

View File

@@ -1,9 +1,7 @@
from f5_tts.model.cfm import CFM
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer

View File

@@ -10,19 +10,18 @@ d - dimension
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNorm_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)

View File

@@ -11,16 +11,15 @@ from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNorm_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)

View File

@@ -8,24 +8,24 @@ d - dimension
"""
from __future__ import annotations
from typing import Literal
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,
Attention,
AttnProcessor,
ConvNeXtV2Block,
ConvPositionEmbedding,
FeedForward,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)

View File

@@ -270,7 +270,7 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text

View File

@@ -19,6 +19,7 @@ from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
@@ -51,7 +52,7 @@ class Trainer:
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
cfg_dict: dict = dict(), # training config
model_cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -73,8 +74,8 @@ class Trainer:
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
if not cfg_dict:
cfg_dict = {
if not model_cfg_dict:
model_cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
@@ -85,11 +86,11 @@ class Trainer:
"max_grad_norm": max_grad_norm,
"noise_scheduler": noise_scheduler,
}
cfg_dict["gpus"] = self.accelerator.num_processes
model_cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=cfg_dict,
config=model_cfg_dict,
)
elif self.logger == "tensorboard":
@@ -350,7 +351,7 @@ class Trainer:
progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
desc=f"Epoch {epoch+1}/{self.epochs}",
desc=f"Epoch {epoch + 1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
@@ -395,6 +396,9 @@ class Trainer:
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update)
@@ -428,9 +432,7 @@ class Trainer:
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
self.model.train()
self.save_checkpoint(global_update, last=True)

View File

@@ -5,11 +5,10 @@ import random
from collections import defaultdict
from importlib.resources import files
import torch
from torch.nn.utils.rnn import pad_sequence
import jieba
from pypinyin import lazy_pinyin, Style
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything

View File

@@ -0,0 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
WORKDIR /workspace

View File

@@ -0,0 +1,69 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
```
### Build Image
Build the docker image from scratch.
```sh
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside the docker container, we follow the official TensorRT-LLM guide to build the qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
```
### HTTP Client
```sh
python3 client_http.py
```
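`client_http.py` also accepts explicit flags; a minimal sketch of an invocation is shown below, where the target text and output filename are only placeholders (the reference audio and text are the script defaults):
```sh
# flags as defined in client_http.py; target text below is just a placeholder
python3 client_http.py \
    --server-url localhost:8000 \
    --reference-audio ../../infer/examples/basic/basic_ref_en.wav \
    --reference-text "Some call me nature, others call me mother nature." \
    --target-text "This target sentence is only a placeholder." \
    --output-audio output.wav
```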
### Benchmark using Client-Server Mode
```sh
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark using Offline TRT-LLM Mode
```sh
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
```
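The `$F5_TTS_HF_DOWNLOAD_PATH`, `$model`, `$vocoder_trt_engine_path`, and `$F5_TTS_TRT_LLM_ENGINE_PATH` variables above are assumed to be exported beforehand (for example by `run.sh`); a hypothetical assignment, with paths chosen purely for illustration, could look like:
```sh
# illustrative values only; point these at wherever your checkpoints and engines actually live
model=F5TTS_Base
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS                    # downloaded checkpoints (model_1200000.pt, vocab.txt)
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_tts_trt_llm_engine  # directory with the built TRT-LLM engine and config.json
vocoder_trt_engine_path=./vocos_vocoder.plan        # optional TensorRT vocos engine
```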
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
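For context, RTF here is computed as total decoding time divided by total generated audio duration (as in the benchmark scripts below), so an RTF of 0.0394 corresponds to roughly 0.0394 × 1000 ≈ 39.4 ms of compute per second of synthesized speech.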
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -0,0 +1,560 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
"""
import argparse
import json
import os
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
torch.manual_seed(0)
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
parser.add_argument(
"--vocab-file",
required=True,
type=str,
help="vocab file",
)
parser.add_argument(
"--model-path",
required=True,
type=str,
help="model path, to load text embedding",
)
parser.add_argument(
"--tllm-model-dir",
required=True,
type=str,
help="tllm model dir",
)
parser.add_argument(
"--batch-size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
parser.add_argument(
"--vocoder",
default="vocos",
type=str,
help="vocoder name",
)
parser.add_argument(
"--vocoder-trt-engine-path",
default=None,
type=str,
help="vocoder trt engine path",
)
parser.add_argument("--enable-warmup", action="store_true")
parser.add_argument("--remove-input-padding", action="store_true")
parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
args = parser.parse_args()
return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
# pad along the last dimension
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
[],
[],
[],
[],
[],
)
for i, item in enumerate(batch):
item_id, prompt_text, target_text = (
item["id"],
item["prompt_text"],
item["target_text"],
)
ids.append(item_id)
reference_target_texts_list.append(prompt_text + target_text)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze()
ref_mel_len = ref_mel.shape[0]
assert ref_mel.shape[1] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
max_seq_len = max(estimated_reference_target_mel_len)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
"estimated_reference_target_mel_len": estimated_reference_target_mel_len,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
# Initialize process group with explicit device IDs
dist.init_process_group(
"nccl",
)
return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
if vocoder_name == "vocos":
if vocoder_trt_engine_path is not None:
vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
else:
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
from vocos.feature_extractors import EncodecFeatures
if isinstance(vocoder.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not implemented yet")
return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
self.session = Session.from_serialized_engine(engine_buffer)
self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
def decode(self, mels):
mels = mels.contiguous()
inputs = {"mel": mels}
output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
outputs = {
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
}
ok = self.session.run(inputs, outputs, self.stream)
assert ok, "Runtime execution failed for vae session"
samples = outputs["waveform"]
return samples
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
if args.backend_type == "trt":
model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
)
elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
pe_attn_head=1,
text_mask_padding=False,
)
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
)
dataset = load_dataset(
"yuekai/seed_tts",
split=args.split_name,
trust_remote_code=True,
)
def add_estimated_duration(example):
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
estimated_duration = prompt_audio_len * scale_factor
example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
return example
dataset = dataset.map(add_estimated_duration)
dataset = dataset.sort("estimated_duration", reverse=True)
if args.use_perf:
# dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
# dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
# dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
dataset = datasets.concatenate_datasets(dataset_list_short)
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
else:
# This would disable shuffling
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
)
total_steps = len(dataset)
if args.enable_warmup:
for batch in dataloader:
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
)
elif args.backend_type == "pytorch":
with torch.inference_mode():
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device)
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
decoding_time = 0
vocoder_time = 0
total_duration = 0
if args.use_perf:
torch.cuda.cudart().cudaProfilerStart()
total_decoding_time = time.time()
for batch in dataloader:
if args.use_perf:
torch.cuda.nvtx.range_push("data sample")
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
ref_mels,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
use_perf=args.use_perf,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if args.vocoder == "vocos":
if args.use_perf:
torch.cuda.nvtx.range_push("vocoder decode")
generated_wave = vocoder.decode(gen_mel_spec).cpu()
if args.use_perf:
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",
generated_wave,
target_sample_rate,
)
total_duration += generated_wave.shape[1] / target_sample_rate
vocoder_time += time.time() - vocoder_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
total_decoding_time = time.time() - total_decoding_time
if rank == 0:
progress_bar.close()
rtf = total_decoding_time / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
s += f"batch size: {args.batch_size}\n"
print(s)
with open(f"{args.output_dir}/rtf.txt", "w") as f:
f.write(s)
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,469 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a dataset from Hugging Face and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
# write a note explaining that the log comes from triton_client.get_inference_statistics(), for readability
summary_f.write(
"The log is parsed from triton_client.get_inference_statistics(), for better human readability. \n"
)
summary_f.write("To learn more about the log, please refer to: \n")
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
summary_f.write(
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
)
summary_f.write(
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
)
summary_f.write(
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
)
summary_f.write(
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
)
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
)
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Grpc port of the triton server, default is 8001",
)
parser.add_argument(
"--reference-audio",
type=str,
default=None,
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--huggingface-dataset",
type=str,
default="yuekai/seed_tts",
help="dataset name in huggingface dataset hub",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="dataset split name, default is 'test'",
)
parser.add_argument(
"--manifest-path",
type=str,
default=None,
help="Path to the manifest dir which includes wav.scp trans.txt files.",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
)
parser.add_argument(
"--num-tasks",
type=int,
default=1,
help="Number of concurrent tasks for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-wer",
action="store_true",
default=False,
help="""True to compute WER.
""",
)
parser.add_argument(
"--log-dir",
type=str,
required=False,
default="./tmp",
help="log directory",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Inference batch_size per request for offline mode.",
)
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate
async def send(
manifest_item_list: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
print(f"manifest_item_list: {manifest_item_list}")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
reference_text, target_text = item["reference_text"], item["target_text"]
estimated_target_duration = duration / len(reference_text) * len(target_text)
if padding_duration:
# pad to the next multiple of padding_duration seconds
samples = np.zeros(
(
1,
padding_duration
* sample_rate
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
input_data_numpy = np.array([reference_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[2].set_data_from_numpy(input_data_numpy)
input_data_numpy = np.array([target_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
sequence_id = 100000000 + i + task_id * 10
start = time.time()
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
audio = response.as_numpy("waveform").reshape(-1)
end = time.time() - start
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, estimated_target_duration))
total_duration += estimated_target_duration
return total_duration, latency_data
def load_manifests(manifest_path):
with open(manifest_path, "r") as f:
manifest_list = []
for line in f:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
{
"audio_filepath": prompt_wav,
"reference_text": prompt_text,
"target_text": gt_text,
"target_audio_path": utt,
}
)
return manifest_list
def split_data(data, k):
n = len(data)
if n < k:
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
k = n
quotient = n // k
remainder = n % k
result = []
start = 0
for i in range(k):
if i < remainder:
end = start + quotient + 1
else:
end = start + quotient
result.append(data[start:end])
start = end
return result
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient
if args.reference_audio:
args.num_tasks = 1
args.log_interval = 1
manifest_item_list = [
{
"reference_text": args.reference_text,
"target_text": args.target_text,
"audio_filepath": args.reference_audio,
"target_audio_path": "test",
}
]
elif args.huggingface_dataset:
import datasets
dataset = datasets.load_dataset(
args.huggingface_dataset,
split=args.split_name,
trust_remote_code=True,
)
manifest_item_list = []
for i in range(len(dataset)):
manifest_item_list.append(
{
"audio_filepath": dataset[i]["prompt_audio"],
"reference_text": dataset[i]["prompt_text"],
"target_audio_path": dataset[i]["id"],
"target_text": dataset[i]["target_text"],
}
)
else:
manifest_item_list = load_manifests(args.manifest_path)
args.num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, args.num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
tasks = []
start_time = time.time()
for i in range(args.num_tasks):
task = asyncio.create_task(
send(
manifest_item_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
total_duration = 0.0
latency_data = []
for ans in ans_list:
total_duration += ans[0]
latency_data += ans[1]
rtf = elapsed / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
print(s)
if args.manifest_path:
name = Path(args.manifest_path).stem
elif args.split_name:
name = args.split_name
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,143 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import numpy as np
import requests
import soundfile as sf
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-url",
type=str,
default="localhost:8000",
help="Address of the server",
)
parser.add_argument(
"--reference-audio",
type=str,
default="../../infer/examples/basic/basic_ref_en.wav",
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="Some call me nature, others call me mother nature.",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
help="",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
reference_text,
target_text,
sample_rate=16000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
"datatype": "INT32",
"data": lengths.tolist(),
},
{"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
{"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
]
}
return data
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
)
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
sf.write(args.output_audio, audio, 24000, "PCM_16")

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-f5-tts:24.12
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"

View File

@@ -0,0 +1,430 @@
import math
import os
import time
from functools import wraps
from typing import List, Optional
import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
# Audio tensor case: batch, seq_len, feature_len
# position_ids case: batch, seq_len
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
# Initialize a list to collect valid sequences
valid_sequences = []
for i in range(input_tensor.shape[0]):
valid_length = input_tensor_lengths[i]
valid_sequences.append(input_tensor[i, :valid_length])
# Concatenate all valid sequences along the batch dimension
output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
return output_tensor
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
if use_ema:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
return text_embed_dict
class F5TTS(object):
def __init__(
self,
config,
debug_mode=True,
stream: Optional[torch.cuda.Stream] = None,
tllm_model_dir: Optional[str] = None,
model_path: Optional[str] = None,
vocab_size: Optional[int] = None,
):
self.dtype = config["pretrained_config"]["dtype"]
rank = tensorrt_llm.mpi_rank()
world_size = config["pretrained_config"]["mapping"]["world_size"]
cp_size = config["pretrained_config"]["mapping"]["cp_size"]
tp_size = config["pretrained_config"]["mapping"]["tp_size"]
pp_size = config["pretrained_config"]["mapping"]["pp_size"]
assert pp_size == 1
self.mapping = tensorrt_llm.Mapping(
world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
)
local_rank = rank % self.mapping.gpus_per_node
self.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(self.device)
self.stream = stream
if self.stream is None:
self.stream = torch.cuda.Stream(self.device)
torch.cuda.set_stream(self.stream)
engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
logger.info(f"Loading engine from {engine_file}")
with open(engine_file, "rb") as f:
engine_buffer = f.read()
assert engine_buffer is not None
self.session = Session.from_serialized_engine(engine_buffer)
self.debug_mode = debug_mode
self.inputs = {}
self.outputs = {}
self.buffer_allocated = False
expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
logger.error(
f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
)
logger.error(
f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
)
logger.error(f"Expected tensor names: {expected_tensor_names}")
logger.error(f"Found tensor names: {found_tensor_names}")
raise RuntimeError("Tensor names in engine are not the same as expected.")
if self.debug_mode:
self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
emb = time_step[i] * emb_factor
time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
self.time_expand = time_expand.to(self.device)
self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
def _tensor_dtype(self, name):
# return torch dtype given tensor name for convenience
dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
return dtype
def _setup(self, batch_size, seq_len):
for i in range(self.session.engine.num_io_tensors):
name = self.session.engine.get_tensor_name(i)
if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = list(self.session.engine.get_tensor_shape(name))
shape[0] = batch_size
shape[1] = seq_len
self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
self.buffer_allocated = True
def cuda_stream_guard(func):
"""Sync external stream and set current stream to the one bound to the session. Reset on exit."""
@wraps(func)
def wrapper(self, *args, **kwargs):
external_stream = torch.cuda.current_stream()
if external_stream != self.stream:
external_stream.synchronize()
torch.cuda.set_stream(self.stream)
ret = func(self, *args, **kwargs)
if external_stream != self.stream:
self.stream.synchronize()
torch.cuda.set_stream(external_stream)
return ret
return wrapper
@cuda_stream_guard
def forward(
self,
noise: torch.Tensor,
cond: torch.Tensor,
time_expand: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
input_lengths: torch.Tensor,
delta_t: torch.Tensor,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("flow matching")
cfg_strength = 2.0
batch_size = noise.shape[0]
half_batch = batch_size // 2
noise_half = noise[:half_batch] # Store the initial half of noise
input_type = str_dtype_to_torch(self.dtype)
# Keep a copy of the initial tensors
cond = cond.to(input_type)
rope_cos = rope_cos.to(input_type)
rope_sin = rope_sin.to(input_type)
input_lengths = input_lengths.to(str_dtype_to_torch("int32"))
# Instead of iteratively updating noise within a single model context,
# we'll do a single forward pass for each iteration with fresh context setup
for i in range(self.nfe_steps):
# Re-setup the buffers for clean execution
self._setup(batch_size, noise.shape[1])
if not self.buffer_allocated:
raise RuntimeError("Buffer not allocated, please call setup first!")
# Re-create combined noises for this iteration
current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)
# Get time step for this iteration
current_time = time_expand[:, i].to(input_type)
# Create fresh input dictionary for this iteration
current_inputs = {
"noise": current_noise,
"cond": cond,
"time": current_time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}
# Update inputs and set shapes
self.inputs.clear() # Clear previous inputs
self.inputs.update(**current_inputs)
self.session.set_shapes(self.inputs)
if use_perf:
torch.cuda.nvtx.range_push(f"execute {i}")
ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
assert ok, "Failed to execute model"
# self.session.context.execute_async_v3(self.stream.cuda_stream)
if use_perf:
torch.cuda.nvtx.range_pop()
# Process results
t_scale = delta_t[i].unsqueeze(0).to(input_type)
# Extract predictions
pred_cond = self.outputs["denoised"][:half_batch]
pred_uncond = self.outputs["denoised"][half_batch:]
# Apply classifier-free guidance with safeguards
guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
# Calculate update for noise
noise_half = noise_half + guidance * t_scale
if use_perf:
torch.cuda.nvtx.range_pop()
return noise_half
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
)
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
text_embedding_drop,
),
dim=-1,
)
time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
# Convert estimated_reference_target_mel_len to tensor
input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
# combine above along the batch dimension
inputs = {
"noise": torch.cat((noise, noise), dim=0).contiguous(),
"cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
"time_expand": time_expand,
"rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
"rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
"input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
"delta_t": self.delta_t,
}
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding")
if remove_input_padding:
max_seq_len = inputs["cond"].shape[1]
inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
# for time_expand, convert from B,D to B,T,D by repeat
inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
if use_perf:
torch.cuda.nvtx.range_pop()
start_time = time.time()
denoised = self.forward(**inputs, use_perf=use_perf)
cost_time = time.time() - start_time
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding output")
if remove_input_padding:
denoised_list = []
start_idx = 0
for i in range(batch):
denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
start_idx += inputs["input_lengths"][i]
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
return denoised_list, cost_time
return denoised, cost_time
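The forward loop above is a fixed-step Euler integration of the flow-matching ODE with classifier-free guidance: each of the nfe_steps runs the duplicated (conditional + unconditional) batch through the engine, combines the two velocity predictions, and advances by the sway-sampled step delta_t. A minimal PyTorch sketch of that update rule, with model_fn as a hypothetical stand-in for the TensorRT engine call:
import torch

def euler_cfg_sample(model_fn, noise, nfe_steps=16, cfg_strength=2.0):
    # sway-sampled time grid, as built in F5TTS.__init__ above (sway coefficient -1.0)
    t = torch.linspace(0, 1, nfe_steps + 1)
    t = t - (torch.cos(torch.pi * 0.5 * t) - 1 + t)
    delta_t = torch.diff(t)
    x = noise
    for i in range(nfe_steps):
        # model_fn returns the conditional / unconditional velocities (the two batch halves above)
        v_cond, v_uncond = model_fn(x, t[i])
        v = v_cond + cfg_strength * (v_cond - v_uncond)   # classifier-free guidance
        x = x + delta_t[i] * v                            # Euler step
    return x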

View File

@@ -0,0 +1,278 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import jieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
from torch.utils.dlpack import from_dlpack, to_dlpack
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
class TritonPythonModel:
def initialize(self, args):
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
parameters[key] = value["string_value"]
self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
self.tllm_model_dir = parameters["tllm_model_dir"]
config_file = os.path.join(self.tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
self.model = F5TTS(
config,
debug_mode=False,
tllm_model_dir=self.tllm_model_dir,
model_path=parameters["model_path"],
vocab_size=self.vocab_size,
)
self.vocoder = parameters["vocoder"]
assert self.vocoder in ["vocos", "bigvgan"]
if self.vocoder == "vocos":
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_audio_sample_rate,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
n_mels=self.n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(self.device)
self.compute_mel_fn = self.get_vocos_mel_spectrogram
elif self.vocoder == "bigvgan":
self.compute_mel_fn = self.get_bigvgan_mel_spectrogram
def get_vocos_mel_spectrogram(self, waveform):
mel = self.mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
def forward_vocoder(self, mel):
mel = mel.to(torch.float32).contiguous().cpu()
input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
inference_request = pb_utils.InferenceRequest(
model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
else:
waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def execute(self, requests):
(
reference_text_list,
target_text_list,
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
for request in requests:
wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode("utf-8")
reference_text_list.append(reference_text)
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode("utf-8")
target_text_list.append(target_text)
text = reference_text + target_text
reference_target_texts_list.append(text)
wav = from_dlpack(wav_tensor.to_dlpack())
wav_len = from_dlpack(wav_lens.to_dlpack())
wav_len = wav_len.squeeze()
assert wav.shape[0] == 1, "Only support batch size 1 for now."
wav = wav[:, :wav_len]
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
if self.use_perf:
torch.cuda.nvtx.range_push("compute_mel")
mel_features = self.compute_mel_fn(wav)
if self.use_perf:
torch.cuda.nvtx.range_pop()
mel_features_list.append(mel_features)
reference_mel_len.append(mel_features.shape[1])
estimated_reference_target_mel_len.append(
int(
mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
)
)
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
denoised, cost_time = self.model.sample(
text_pad_sequence,
mel_features,
reference_mel_len_tensor,
estimated_reference_target_mel_len,
remove_input_padding=False,
use_perf=self.use_perf,
)
if self.use_perf:
torch.cuda.nvtx.range_push("vocoder")
responses = []
for i in range(batch):
ref_mel_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
responses.append(inference_response)
if self.use_perf:
torch.cuda.nvtx.range_pop()
return responses
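The length of the generated segment is not predicted by a separate duration model; execute() simply scales the reference mel length by the byte-length ratio of target to reference text. A small worked example of that estimate (all numbers are illustrative):
# A ~4 s reference at 24 kHz with hop 256 gives roughly 375 mel frames.
ref_mel_frames = 375
ref_text = "Some call me nature, others call me mother nature."
target_text = "I don't really care what you call me."
ratio = len(target_text.encode("utf-8")) / len(ref_text.encode("utf-8"))
estimated_total_frames = int(ref_mel_frames * (1 + ratio))  # about 652 frames, later capped by max_mel_len = 3000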

View File

@@ -0,0 +1,81 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "f5_tts"
backend: "python"
max_batch_size: 4
dynamic_batching {
max_queue_delay_microseconds: 1000
}
parameters [
{
key: "vocab_file"
value: { string_value: "${vocab}"}
},
{
key: "model_path",
value: {string_value:"${model}"}
},
{
key: "tllm_model_dir",
value: {string_value:"${trtllm}"}
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
},
{
key: "vocoder",
value: {string_value:"${vocoder}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: True
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: True
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
}
]
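The ${vocab}, ${model}, ${trtllm} and ${vocoder} placeholders are filled at deploy time by scripts/fill_template.py (run.sh stage 3 further below). After substitution, an entry looks like the following, with an illustrative local path:
{
key: "vocab_file"
value: { string_value: "./F5-TTS/F5TTS_Base/vocab.txt"}
}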

View File

@@ -0,0 +1,32 @@
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"
max_batch_size: 4
input [
{
name: "mel"
data_type: TYPE_FP32
dims: [ 100, -1 ]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
dynamic_batching {
preferred_batch_size: [1, 2, 4]
max_queue_delay_microseconds: 1
}
instance_group [
{
count: 1
kind: KIND_GPU
}
]

View File

@@ -0,0 +1,199 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .baichuan.model import BaichuanForCausalLM
from .bert.model import (
BertForQuestionAnswering,
BertForSequenceClassification,
BertModel,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaModel,
)
from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.config import ChatGLMConfig
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .cogvlm.config import CogVLMConfig
from .cogvlm.model import CogVLMForCausalLM
from .commandr.model import CohereForCausalLM
from .dbrx.config import DbrxConfig
from .dbrx.model import DbrxForCausalLM
from .deepseek_v1.model import DeepseekForCausalLM
from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .f5tts.model import F5TTS
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
from .gptj.config import GPTJConfig
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
from .grok.model import GrokForCausalLM
from .llama.config import LLaMAConfig
from .llama.model import LLaMAForCausalLM, LLaMAModel
from .mamba.model import MambaForCausalLM
from .medusa.config import MedusaConfig
from .medusa.model import MedusaForCausalLm
from .mllama.model import MLLaMAModel
from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi.model import PhiForCausalLM, PhiModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
__all__ = [
"BertModel",
"BertForQuestionAnswering",
"BertForSequenceClassification",
"RobertaModel",
"RobertaForQuestionAnswering",
"RobertaForSequenceClassification",
"BloomModel",
"BloomForCausalLM",
"DiT",
"DeepseekForCausalLM",
"FalconConfig",
"DeepseekV2ForCausalLM",
"FalconForCausalLM",
"FalconModel",
"GPTConfig",
"GPTModel",
"GPTForCausalLM",
"OPTForCausalLM",
"OPTModel",
"LLaMAConfig",
"LLaMAForCausalLM",
"LLaMAModel",
"MedusaConfig",
"MedusaForCausalLm",
"ReDrafterForCausalLM",
"GPTJConfig",
"GPTJModel",
"GPTJForCausalLM",
"GPTNeoXModel",
"GPTNeoXForCausalLM",
"PhiModel",
"PhiConfig",
"Phi3Model",
"Phi3Config",
"PhiForCausalLM",
"Phi3ForCausalLM",
"ChatGLMConfig",
"ChatGLMForCausalLM",
"ChatGLMModel",
"BaichuanForCausalLM",
"QWenConfigQWenForCausalLM",
"QWenModel",
"EncoderModel",
"DecoderModel",
"PretrainedConfig",
"PretrainedModel",
"WhisperEncoder",
"MambaForCausalLM",
"MambaConfig",
"MPTForCausalLM",
"MPTModel",
"SkyworkForCausalLM",
"GemmaConfig",
"GemmaForCausalLM",
"DbrxConfig",
"DbrxForCausalLM",
"RecurrentGemmaForCausalLM",
"CogVLMConfig",
"CogVLMForCausalLM",
"EagleForCausalLM",
"SpeculativeDecodingMode",
"CohereForCausalLM",
"MLLaMAModel",
"F5TTS",
]
MODEL_MAP = {
"GPT2LMHeadModel": GPTForCausalLM,
"GPT2LMHeadCustomModel": GPTForCausalLM,
"GPTBigCodeForCausalLM": GPTForCausalLM,
"Starcoder2ForCausalLM": GPTForCausalLM,
"FuyuForCausalLM": GPTForCausalLM,
"Kosmos2ForConditionalGeneration": GPTForCausalLM,
"JAISLMHeadModel": GPTForCausalLM,
"GPTForCausalLM": GPTForCausalLM,
"NemotronForCausalLM": GPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"BloomForCausalLM": BloomForCausalLM,
"RWForCausalLM": FalconForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"Phi3ForCausalLM": Phi3ForCausalLM,
"Phi3VForCausalLM": Phi3ForCausalLM,
"Phi3SmallForCausalLM": Phi3ForCausalLM,
"PhiMoEForCausalLM": Phi3ForCausalLM,
"MambaForCausalLM": MambaForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"GLMModel": ChatGLMForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForCausalLM": ChatGLMForCausalLM,
"LlamaForCausalLM": LLaMAForCausalLM,
"ExaoneForCausalLM": LLaMAForCausalLM,
"MistralForCausalLM": LLaMAForCausalLM,
"MixtralForCausalLM": LLaMAForCausalLM,
"ArcticForCausalLM": LLaMAForCausalLM,
"Grok1ModelForCausalLM": GrokForCausalLM,
"InternLMForCausalLM": LLaMAForCausalLM,
"InternLM2ForCausalLM": LLaMAForCausalLM,
"MedusaForCausalLM": MedusaForCausalLm,
"ReDrafterForCausalLM": ReDrafterForCausalLM,
"BaichuanForCausalLM": BaichuanForCausalLM,
"BaiChuanForCausalLM": BaichuanForCausalLM,
"SkyworkForCausalLM": LLaMAForCausalLM,
GEMMA_ARCHITECTURE: GemmaForCausalLM,
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
"QWenLMHeadModel": QWenForCausalLM,
"QWenForCausalLM": QWenForCausalLM,
"Qwen2ForCausalLM": QWenForCausalLM,
"Qwen2MoeForCausalLM": QWenForCausalLM,
"Qwen2ForSequenceClassification": QWenForCausalLM,
"Qwen2VLForConditionalGeneration": QWenForCausalLM,
"WhisperEncoder": WhisperEncoder,
"EncoderModel": EncoderModel,
"DecoderModel": DecoderModel,
"DbrxForCausalLM": DbrxForCausalLM,
"RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM,
"CogVLMForCausalLM": CogVLMForCausalLM,
"DiT": DiT,
"DeepseekForCausalLM": DeepseekForCausalLM,
"DeciLMForCausalLM": DeciLMForCausalLM,
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"EagleForCausalLM": EagleForCausalLM,
"CohereForCausalLM": CohereForCausalLM,
"MllamaForConditionalGeneration": MLLaMAModel,
"BertForQuestionAnswering": BertForQuestionAnswering,
"BertForSequenceClassification": BertForSequenceClassification,
"BertModel": BertModel,
"RobertaModel": RobertaModel,
"RobertaForQuestionAnswering": RobertaForQuestionAnswering,
"RobertaForSequenceClassification": RobertaForSequenceClassification,
"F5TTS": F5TTS,
}
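The purpose of this patched __init__.py is to register F5TTS alongside the stock TensorRT-LLM model classes, so that a converted checkpoint whose config names the "F5TTS" architecture can be resolved through MODEL_MAP to the DiT model defined in the next file. A quick sanity check after copying the patch (assumes the patched tensorrt_llm package is importable):
from tensorrt_llm.models import MODEL_MAP, F5TTS
assert MODEL_MAP["F5TTS"] is F5TTS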

View File

@@ -0,0 +1,222 @@
from __future__ import annotations
import os
import sys
from collections import OrderedDict
import tensorrt as trt
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)
class InputEmbedding(Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x, cond):
x = self.proj(concat([x, cond], dim=-1))
return self.conv_pos_embed(x) + x
class F5TTS(PretrainedModel):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.dtype = str_dtype_to_trt(config.dtype)
self.time_embed = TimestepEmbedding(config.hidden_size)
self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
self.dim = config.hidden_size
self.depth = config.num_hidden_layers
self.transformer_blocks = ModuleList(
[
DiTBlock(
dim=self.dim,
heads=config.num_attention_heads,
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
)
for _ in range(self.depth)
]
)
self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
self.proj_out = Linear(config.hidden_size, config.mel_dim)
def forward(
self,
noise, # noised input audio
cond, # masked cond audio
time, # time step
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
):
t = self.time_embed(time)
x = self.input_embed(noise, cond)
for block in self.transformer_blocks:
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
denoise = self.proj_out(self.norm_out(x, t))
denoise.mark_output("denoised", self.dtype)
return denoise
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
if default_net().plugin_config.remove_input_padding:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, mel_size],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, concat_feature_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
else:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, -1, mel_size],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, -1, concat_feature_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
input_lengths = Tensor(
name="input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size", [batch_size_range])]),
)
return {
"noise": noise,
"cond": cond,
"time": time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}
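prepare_inputs declares two input layouts: with remove_input_padding enabled, the batch is packed into a single frames axis ([num_frames, feature]); otherwise the usual padded [batch_size, max_duration, feature] layout is used. A small sketch of what packing does to the shapes, matching remove_tensor_padding in f5_tts_trtllm.py above (values are illustrative):
import torch

batch, max_len, n_mels = 2, 5, 100
lengths = torch.tensor([3, 5])                # valid frames per item
padded = torch.randn(batch, max_len, n_mels)  # padded layout: (2, 5, 100)
packed = torch.cat([padded[i, : lengths[i]] for i in range(batch)], dim=0)
print(packed.shape)                           # torch.Size([8, 100]) -> matches shape [-1, mel_size]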

View File

@@ -0,0 +1,412 @@
from __future__ import annotations
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
Tensor,
bert_attention,
cast,
chunk,
concat,
constant,
expand,
expand_dims,
expand_dims_like,
expand_mask,
gelu,
matmul,
permute,
shape,
silu,
slice,
softmax,
squeeze,
unsqueeze,
view,
)
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module
class FeedForward(Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.project_in = Linear(dim, inner_dim)
self.ff = Linear(inner_dim, dim_out)
def forward(self, x):
return self.ff(gelu(self.project_in(x)))
class AdaLayerNormZero(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 6)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
x = self.norm(x)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = x * (ones + scale_msa) + shift_msa
else:
x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZero_Final(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 2)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(silu(emb))
scale, shift = chunk(emb, 2, dim=1)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = self.norm(x) * (ones + scale) + shift
else:
x = self.norm(x) * unsqueeze((ones + scale), 1)
x = x + unsqueeze(shift, 1)
return x
class ConvPositionEmbedding(Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.mish = Mish()
def forward(self, x, mask=None): # noqa: F722
if default_net().plugin_config.remove_input_padding:
x = unsqueeze(x, 0)
x = permute(x, [0, 2, 1])
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
out = permute(x, [0, 2, 1])
if default_net().plugin_config.remove_input_padding:
out = squeeze(out, 0)
return out
class Attention(Module):
def __init__(
self,
processor: AttnProcessor,
dim: int,
heads: int = 16,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim # hidden_size
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.attention_head_size = dim_head
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.tp_size = 1
self.num_attention_heads = heads // self.tp_size
self.num_attention_kv_heads = heads // self.tp_size # 8
self.dtype = str_dtype_to_trt("float32")
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
self.to_q = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_k = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_v = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_dim is not None:
self.to_k_c = Linear(context_dim, self.inner_dim)
self.to_v_c = Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = Linear(context_dim, self.inner_dim)
self.to_out = RowLinear(
self.tp_size * self.num_attention_heads * self.attention_head_size,
dim,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = Linear(self.inner_dim, dim)
def forward(
self,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
c=None, # context c
scale=1.0,
rope=None,
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
else:
return self.processor(
self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
)
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
shape_tensor = concat(
[shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
)
if default_net().plugin_config.remove_input_padding:
assert tensor.ndim() == 2
x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
x1 = expand_dims(x1, 2)
x2 = expand_dims(x2, 2)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 2)
out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
else:
assert tensor.ndim() == 3
x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
x1 = expand_dims(x1, 3)
x2 = expand_dims(x2, 3)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 3)
out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (-1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
return out
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
rope=None,
) -> torch.FloatTensor:
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
# attention
inner_dim = key.shape[-1]
norm_factor = math.sqrt(attn.attention_head_size)
q_scaling = 1.0 / norm_factor
mask = None
if not default_net().plugin_config.remove_input_padding:
N = shape(x, 1)
B = shape(x, 0)
seq_len_2d = concat([1, N])
max_position_embeddings = 4096
# create position ids
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
mask = tmp_position_ids < tmp_input_lengths # BxL
mask = mask.cast("int32")
if default_net().plugin_config.bert_attention_plugin:
qkv = concat([query, key, value], dim=-1)
# TRT plugin mode
assert input_lengths is not None
if default_net().plugin_config.remove_input_padding:
qkv = qkv.view(concat([-1, 3 * inner_dim]))
max_input_length = constant(
np.zeros(
[
2048,
],
dtype=np.int32,
)
)
else:
max_input_length = None
context = bert_attention(
qkv,
input_lengths,
attn.num_attention_heads,
attn.attention_head_size,
q_scaling=q_scaling,
max_input_length=max_input_length,
)
else:
assert not default_net().plugin_config.remove_input_padding
def transpose_for_scores(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.transpose(1, 2)
return y
def transpose_for_scores_k(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.permute([0, 2, 3, 1])
return y
query = transpose_for_scores(query)
key = transpose_for_scores_k(key)
value = transpose_for_scores(value)
attention_scores = matmul(query, key, use_fp32_acc=False)
if mask is not None:
attention_mask = expand_mask(mask, shape(query, 2))
attention_mask = cast(attention_mask, attention_scores.dtype)
attention_scores = attention_scores + attention_mask
attention_probs = softmax(attention_scores, dim=-1)
context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
context = attn.to_out(context)
if mask is not None:
mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
mask = expand_dims_like(mask, context)
mask = cast(mask, context.dtype)
context = context * mask
return context
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
def forward(
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=None
): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
# norm ----> (2,1226,1024)
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
# process attention output for input x
if default_net().plugin_config.remove_input_padding:
x = x + gate_msa * attn_output
else:
x = x + unsqueeze(gate_msa, 1) * attn_output
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
else:
norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
# norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
ff_output = self.ff(norm)
if default_net().plugin_config.remove_input_padding:
x = x + gate_mlp * ff_output
else:
x = x + unsqueeze(gate_mlp, 1) * ff_output
return x
class TimestepEmbedding(Module):
def __init__(self, dim, freq_embed_dim=256, dtype=None):
super().__init__()
# self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)
def forward(self, timestep):
t_freq = self.mlp1(timestep)
t_freq = silu(t_freq)
t_emb = self.mlp2(t_freq)
return t_emb
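rotate_every_two_3dim and apply_rotary_pos_emb_3dim implement a GPT-J-style interleaved rotary embedding on the first head_dim channels of the query/key, with rope_cos/rope_sin already repeat-interleaved in F5TTS.__init__ above. A plain PyTorch reference of the same transform, offered only as a sketch for numerically cross-checking the TRT graph:
import torch

def rotate_every_two(x):
    # out[..., 2i] = -x[..., 2i+1], out[..., 2i+1] = x[..., 2i]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary(x, rope_cos, rope_sin):
    # only the first rot_dim channels are rotated; the rest pass through unchanged
    rot_dim = rope_cos.shape[-1]
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    return torch.cat([x_rot * rope_cos + rotate_every_two(x_rot) * rope_sin, x_pass], dim=-1)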

View File

@@ -0,0 +1,24 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -0,0 +1,110 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_Base"
model=F5TTS_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Testing http client"
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "TRT-LLM: offline decoding benchmark test"
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt
batch_size=1
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
fi
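The script is driven by a start stage, a stop stage and an optional model name. For example, to reproduce what the compose file above runs and then smoke-test the server from another shell:
# build engines, assemble the model repo and start tritonserver (blocks in the foreground)
bash run.sh 0 4 F5TTS_Base
# from another shell, once the server is up
bash run.sh 6 6 F5TTS_Base   # HTTP client test against the running server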

View File

@@ -0,0 +1,248 @@
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
# Copyright (c) 2020 Shimin Zhang
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft
support_clp_op = True
else:
from torch import rfft as fft
class STFT(th.nn.Module):
def __init__(
self,
win_len=1024,
win_hop=512,
fft_len=1024,
enframe_mode="continue",
win_type="hann",
win_sqrt=False,
pad_center=True,
):
"""
Implementation of STFT using 1D convolution and 1D transposed convolution.
The signal can be framed in 2 ways, `break` and `continue`.
`break` method is a kaldi-like framing.
`continue` method is a librosa-like framing.
More information about `perfect reconstruction`:
1. https://ww2.mathworks.cn/help/signal/ref/stft.html
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
Args:
win_len (int): Number of points in one frame. Defaults to 1024.
win_hop (int): Frame shift (hop) in samples. Defaults to 512.
fft_len (int): Number of DFT points. Defaults to 1024.
enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
win_type (str, optional): The type of window to create. Defaults to 'hann'.
win_sqrt (bool, optional): whether to use the square-root window. Defaults to False.
pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
"""
super(STFT, self).__init__()
assert enframe_mode in ["break", "continue"]
assert fft_len >= win_len
self.win_len = win_len
self.win_hop = win_hop
self.fft_len = fft_len
self.mode = enframe_mode
self.win_type = win_type
self.win_sqrt = win_sqrt
self.pad_center = pad_center
self.pad_amount = self.fft_len // 2
en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
self.register_buffer("en_k", en_k)
self.register_buffer("fft_k", fft_k)
self.register_buffer("ifft_k", ifft_k)
self.register_buffer("ola_k", ola_k)
def __init_kernel__(self):
"""
Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
** enframe_kernel: Using conv1d layer and identity matrix.
** fft_kernel: Using linear layer for matrix multiplication. In fact,
enframe_kernel and fft_kernel can be combined, But for the sake of
readability, I took the two apart.
** ifft_kernel, pinv of fft_kernel.
** overlap-add kernel, just like enframe_kernel, but transposed.
Returns:
tuple: four kernels.
"""
enframed_kernel = th.eye(self.fft_len)[:, None, :]
if support_clp_op:
tmp = fft(th.eye(self.fft_len))
fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
else:
fft_kernel = fft(th.eye(self.fft_len), 1)
if self.mode == "break":
enframed_kernel = th.eye(self.win_len)[:, None, :]
fft_kernel = fft_kernel[: self.win_len]
fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
window = get_window(self.win_type, self.win_len)
self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
window = th.FloatTensor(window)
if self.mode == "continue":
left_pad = (self.fft_len - self.win_len) // 2
right_pad = left_pad + (self.fft_len - self.win_len) % 2
window = F.pad(window, (left_pad, right_pad))
if self.win_sqrt:
self.padded_window = window
window = th.sqrt(window)
else:
self.padded_window = window**2
fft_kernel = fft_kernel.T * window
ifft_kernel = ifft_kernel * window
ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
if self.mode == "continue":
ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
def is_perfect(self):
"""
Whether the parameters win_len, win_hop and win_sqrt
obey constants overlap-add(COLA)
Returns:
bool: Return true if parameters obey COLA.
"""
return self.perfect_reconstruct and self.pad_center
def transform(self, inputs, return_type="complex"):
"""Take input data (audio) to STFT domain.
Args:
inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
return_type (str, optional): return (mag, phase) when `magphase`,
return (real, imag) when `realimag` and complex(real, imag) when `complex`.
Defaults to 'complex'.
Returns:
tuple: (mag, phase) when `magphase`, return (real, imag) when
`realimag`. Defaults to 'complex', each elements with shape
[num_batch, num_frequencies, num_frames]
"""
assert return_type in ["magphase", "realimag", "complex"]
if inputs.dim() == 2:
inputs = th.unsqueeze(inputs, 1)
self.num_samples = inputs.size(-1)
if self.pad_center:
inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
outputs = th.transpose(enframe_inputs, 1, 2)
outputs = F.linear(outputs, self.fft_k)
outputs = th.transpose(outputs, 1, 2)
dim = self.fft_len // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
if return_type == "realimag":
return real, imag
elif return_type == "complex":
assert support_clp_op
return th.complex(real, imag)
else:
mags = th.sqrt(real**2 + imag**2)
phase = th.atan2(imag, real)
return mags, phase
def inverse(self, input1, input2=None, input_type="magphase"):
"""Call the inverse STFT (iSTFT), given tensors produced
by the `transform` function.
Args:
input1 (tensors): Magnitude/Real-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input2 (tensors): Phase/Imag-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input_type (str, optional): Mathematical meaning of input tensor's.
Defaults to 'magphase'.
Returns:
tensors: Reconstructed audio given magnitude and phase. Of
shape [num_batch, num_samples]
"""
assert input_type in ["magphase", "realimag"]
if input_type == "realimag":
real, imag = None, None
if support_clp_op and th.is_complex(input1):
real, imag = input1.real, input1.imag
else:
real, imag = input1, input2
else:
real = input1 * th.cos(input2)
imag = input1 * th.sin(input2)
inputs = th.cat([real, imag], dim=1)
outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
t = t.to(inputs.device)
coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
num_frames = input1.size(-1)
num_samples = num_frames * self.win_hop
rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
outputs = outputs[..., rm_start:rm_end]
coff = coff[..., rm_start:rm_end]
coffidx = th.where(coff > 1e-8)
outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
return outputs.squeeze(dim=1)
def forward(self, inputs):
"""Take input data (audio) to STFT domain and then back to audio.
Args:
inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]
Returns:
tensor: Reconstructed audio given magnitude and phase.
Of shape [num_batch, num_samples]
"""
mag, phase = self.transform(inputs)
rec_wav = self.inverse(mag, phase)
return rec_wav
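# Usage sketch (editor's addition, not part of the original file): a minimal
# round-trip through the STFT module above. Batch size, signal length and hop
# size are illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    stft = STFT(win_len=1024, win_hop=256, fft_len=1024, win_type="hann", pad_center=True)
    wav = th.randn(2, 16000)  # (num_batch, num_samples)
    mag, phase = stft.transform(wav, return_type="magphase")
    print(mag.shape)  # (2, fft_len // 2 + 1, num_frames)
    rec = stft.inverse(mag, phase, input_type="magphase")
    print(rec.shape, stft.is_perfect())  # near-exact reconstruction when COLA holds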

View File

@@ -0,0 +1,358 @@
import argparse
import json
import os
import re
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=1)
return split_v.contiguous()
def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=0)
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
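# Editor's note (sketch, not part of the original script): the per-block entries
# above repeat one pattern for transformer blocks 0-21, so an equivalent mapping
# could be generated programmatically, e.g.:
#
#     _block_mapping = {}
#     for i in range(22):
#         for suffix in ("weight", "bias"):
#             _block_mapping[f"^transformer_blocks.{i}.attn.to_out.0.{suffix}$"] = (
#                 f"transformer_blocks.{i}.attn.to_out.{suffix}"
#             )
#             _block_mapping[f"^transformer_blocks.{i}.ff.ff.0.0.{suffix}$"] = (
#                 f"transformer_blocks.{i}.ff.project_in.{suffix}"
#             )
#             _block_mapping[f"^transformer_blocks.{i}.ff.ff.2.{suffix}$"] = (
#                 f"transformer_blocks.{i}.ff.ff.{suffix}"
#             )
#     assert all(FACEBOOK_DIT_NAME_MAPPING[k] == v for k, v in _block_mapping.items())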
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
args = parser.parse_args()
return args
def convert_timm_dit(args, mapping, dtype="float32"):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
weights = dict()
for name, param in model_params.items():
if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
else:
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
scale_factor = math.pow(64, -0.25)
for k, v in weights.items():
if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
weights[k] *= scale_factor
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
weights[k] *= scale_factor
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Weights loaded. Total time: {t}")
return weights
def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
"tp_size": args.tp_size,
"pp_size": args.pp_size,
},
}
if args.fp8_linear:
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)
def covert_and_save(args, rank):
if rank == 0:
save_config(args)
mapping = Mapping(
world_size=args.cp_size * args.tp_size * args.pp_size,
rank=rank,
cp_size=args.cp_size,
tp_size=args.tp_size,
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
def execute(workers, func, args):
if workers == 1:
for rank, f in enumerate(func):
f(args, rank)
else:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
def main():
args = parse_arguments()
world_size = args.cp_size * args.tp_size * args.pp_size
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
return
print("start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Total time of converting checkpoints: {t}")
if __name__ == "__main__":
main()
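# Editor's note (usage sketch, not part of the original script): after conversion,
# the output directory (default ./tllm_checkpoint) contains config.json plus one
# rank{r}.safetensors shard per rank, which can be reloaded for a quick sanity check:
#
#     import json
#     import safetensors.torch
#     with open("./tllm_checkpoint/config.json") as f:
#         cfg = json.load(f)
#     shard = safetensors.torch.load_file("./tllm_checkpoint/rank0.safetensors")
#     print(cfg["architecture"], len(shard), "tensors")
#     print(shard["transformer_blocks.0.attn.to_out.weight"].shape)  # renamed per the mapping above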

View File

@@ -0,0 +1,138 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch.nn as nn
from conv_stft import STFT
from huggingface_hub import hf_hub_download
from vocos import Vocos
opset_version = 17
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--vocoder",
type=str,
default="vocos",
choices=["vocos", "bigvgan"],
help="Vocoder to export",
)
parser.add_argument(
"--output-path",
type=str,
default="./vocos_vocoder.onnx",
help="Output path",
)
return parser.parse_args()
class ISTFTHead(nn.Module):
def __init__(self, n_fft: int, hop_length: int):
super().__init__()
self.out = None
self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)
def forward(self, x: torch.Tensor):
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(mag, max=1e2)
real = mag * torch.cos(p)
imag = mag * torch.sin(p)
audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
return audio
class VocosVocoder(nn.Module):
def __init__(self, vocos_vocoder):
super(VocosVocoder, self).__init__()
self.vocos_vocoder = vocos_vocoder
istft_head_out = self.vocos_vocoder.head.out
n_fft = self.vocos_vocoder.head.istft.n_fft
hop_length = self.vocos_vocoder.head.istft.hop_length
istft_head_for_export = ISTFTHead(n_fft, hop_length)
istft_head_for_export.out = istft_head_out
self.vocos_vocoder.head = istft_head_for_export
def forward(self, mel):
waveform = self.vocos_vocoder.decode(mel)
return waveform
def export_VocosVocoder(vocos_vocoder, output_path, verbose):
vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
vocos_vocoder.eval()
dummy_batch_size = 8
dummy_input_length = 500
dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()
with torch.no_grad():
dummy_waveform = vocos_vocoder(mel=dummy_mel)
print(dummy_waveform.shape)
dummy_input = dummy_mel
torch.onnx.export(
vocos_vocoder,
dummy_input,
output_path,
opset_version=opset_version,
do_constant_folding=True,
input_names=["mel"],
output_names=["waveform"],
dynamic_axes={
"mel": {0: "batch_size", 2: "input_length"},
"waveform": {0: "batch_size", 1: "output_length"},
},
verbose=verbose,
)
print("Exported to {}".format(output_path))
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
if vocoder_name == "vocos":
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not supported yet")
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
return vocoder
if __name__ == "__main__":
args = get_args()
vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
if args.vocoder == "vocos":
export_VocosVocoder(vocoder, args.output_path, verbose=False)
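# Editor's note (usage sketch, not part of the original script): the exported ONNX
# model can be smoke-tested with onnxruntime; the path and mel shape below mirror
# the defaults above and are illustrative assumptions.
#
#     import numpy as np
#     import onnxruntime as ort
#     sess = ort.InferenceSession("./vocos_vocoder.onnx", providers=["CPUExecutionProvider"])
#     mel = np.random.randn(1, 100, 500).astype(np.float32)
#     (waveform,) = sess.run(["waveform"], {"mel": mel})
#     print(waveform.shape)  # (1, output_length)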

View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
ENGINE_PATH=$2
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"
PRECISION="fp32"
MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"
${TRTEXEC} \
--minShapes="mel:${MEL_MIN_SHAPE}" \
--optShapes="mel:${MEL_OPT_SHAPE}" \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}

View File

@@ -0,0 +1,36 @@
#! /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
def main(file_path, substitutions, in_place, participant_ids):
with open(file_path) as f:
pbtxt = Template(f.read())
sub_dict = {"max_queue_size": 0}
sub_dict["participant_ids"] = participant_ids
for sub in substitutions.split(","):
key, value = sub.split(":")
sub_dict[key] = value
pbtxt = pbtxt.safe_substitute(sub_dict)
if in_place:
with open(file_path, "w") as f:
f.write(pbtxt)
else:
print(pbtxt)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument(
"substitutions",
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
)
parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
args = parser.parse_args()
main(**vars(args))
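# Editor's note (sketch, not part of the original script): substitutions are passed as
# "key1:value1,key2:value2"; under the hood string.Template.safe_substitute fills the
# ${...} placeholders and leaves any unknown ones untouched. The placeholder names
# below are illustrative, not taken from an actual .pbtxt.
#
#     from string import Template
#     pbtxt = Template("max_queue_size: ${max_queue_size}\nmodel_repo: ${model_repo}")
#     print(pbtxt.safe_substitute({"max_queue_size": 0, "model_repo": "/models"}))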

View File

@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")

View File

@@ -1,12 +1,13 @@
import sys
import os
import sys
sys.path.append(os.getcwd())
from f5_tts.model import CFM, DiT
import torch
import thop
import torch
from f5_tts.model import CFM, DiT
""" ~155M """

View File

@@ -1,10 +1,12 @@
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import socket
import time
import numpy as np
import pyaudio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,6 @@
import argparse
import gc
import logging
import numpy as np
import queue
import socket
import struct
@@ -10,20 +9,22 @@ import traceback
import wave
from importlib.resources import files
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.model.backbones.dit import DiT # noqa: F401. used for config
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -80,7 +81,7 @@ class TTSStreamingProcessor:
else "cpu"
)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
self.model_cls = globals()[model_cfg.model.backbone]
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate

View File

@@ -51,7 +51,11 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio
Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
If you want to finetune with a variant version, e.g. *F5TTS_v1_Base_no_zero_init*, manually download the pretrained checkpoint from the model weight repository and fill in the path accordingly on the web interface.
If using tensorboard as the logger, install it first with `pip install tensorboard`.
<ins>Setting `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (trained for only a few updates, so the EMA weights are still dominated by the pretrained ones); try turning it off with the finetune Gradio option or `load_model(..., use_ema=False)` and see if it gives better results.
### 3. W&B Logging

View File

@@ -1,12 +1,13 @@
import os
import sys
import signal
import subprocess # For invoking ffprobe
import shutil
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
@@ -16,12 +17,10 @@ from importlib.resources import files
from pathlib import Path
import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
@@ -122,7 +121,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
for future in tqdm(
chunk_futures,
total=len(chunk),
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
):
try:
result = future.result()
@@ -233,7 +232,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
dataset_name = out_dir.stem
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):

View File

@@ -7,20 +7,18 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
repetition_found,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
out_zh = {
@@ -198,7 +196,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:

View File

@@ -0,0 +1,94 @@
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import json
import os
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
# Define filters for exclusion
out_en = set()
en_filters = ["ا", "い", "て"]
def process_audio_directory(audio_dir):
sub_result, durations, vocab_set = [], [], set()
bad_case_en = 0
for file in audio_dir.iterdir():
if file.suffix == ".json":
with open(file, "r") as f:
obj = json.load(f)
text = obj["text"]
if any(f in text for f in en_filters) or repetition_found(text, length=4):
bad_case_en += 1
continue
duration = obj["duration"]
audio_file = file.with_suffix(".mp3")
if audio_file.exists():
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result, duration_list, text_vocab_set = [], [], set()
total_bad_case_en = 0
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
dataset_path = Path(dataset_dir)
for sub_dir in dataset_path.iterdir():
if sub_dir.is_dir():
futures.append(executor.submit(process_audio_directory, sub_dir))
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_en = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_en += bad_case_en
executor.shutdown()
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"For {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "char"
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
dataset_name = f"Emilia_EN_{tokenizer}"
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
main()

View File

@@ -1,15 +1,17 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):
@@ -72,7 +74,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":

View File

@@ -1,14 +1,16 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def main():
@@ -50,7 +52,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":

View File

@@ -4,15 +4,16 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from tqdm import tqdm
import torchaudio
from datasets import Dataset
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin

View File

@@ -5,9 +5,9 @@ from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
@@ -40,15 +40,15 @@ def parse_args():
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
@@ -65,7 +65,7 @@ def parse_args():
action="store_true",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
parser.add_argument(
"--bnb_optimizer",
action="store_true",

View File

@@ -1,14 +1,12 @@
import gc
import json
import numpy as np
import os
import platform
import psutil
import queue
import random
import re
import signal
import shutil
import signal
import subprocess
import sys
import tempfile
@@ -16,21 +14,23 @@ import threading
import time
from glob import glob
from importlib.resources import files
from scipy.io import wavfile
import click
import gradio as gr
import librosa
import numpy as np
import psutil
import torch
import torchaudio
from cached_path import cached_path
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import load_file, save_file
from scipy.io import wavfile
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import transcribe
from f5_tts.model.utils import convert_char_to_pinyin
training_process = None
@@ -120,11 +120,11 @@ def load_settings(project_name):
default_settings = {
"exp_name": "F5TTS_v1_Base",
"learning_rate": 1e-5,
"batch_size_per_gpu": 1,
"batch_size_type": "sample",
"batch_size_per_gpu": 3200,
"batch_size_type": "frame",
"max_samples": 64,
"grad_accumulation_steps": 4,
"max_grad_norm": 1,
"grad_accumulation_steps": 1,
"max_grad_norm": 1.0,
"epochs": 100,
"num_warmup_updates": 100,
"save_per_updates": 500,
@@ -134,10 +134,12 @@ def load_settings(project_name):
"file_checkpoint_train": "",
"tokenizer_type": "pinyin",
"tokenizer_file": "",
"mixed_precision": "none",
"logger": "wandb",
"mixed_precision": "fp16",
"logger": "none",
"bnb_optimizer": False,
}
if device == "mps":
default_settings["mixed_precision"] = "none"
# Load settings from file if it exists
if os.path.isfile(file_setting):
@@ -361,27 +363,27 @@ def terminate_process(pid):
def start_training(
dataset_name="",
exp_name="F5TTS_v1_Base",
learning_rate=1e-5,
batch_size_per_gpu=1,
batch_size_type="sample",
max_samples=64,
grad_accumulation_steps=4,
max_grad_norm=1.0,
epochs=100,
num_warmup_updates=100,
save_per_updates=500,
keep_last_n_checkpoints=-1,
last_per_updates=100,
finetune=True,
file_checkpoint_train="",
tokenizer_type="pinyin",
tokenizer_file="",
mixed_precision="fp16",
stream=False,
logger="wandb",
ch_8bit_adam=False,
dataset_name,
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
stream,
logger,
ch_8bit_adam,
):
global training_process, tts_api, stop_signal
@@ -458,7 +460,10 @@ def start_training(
cmd += f" --tokenizer {tokenizer_type}"
cmd += f" --log_samples --logger {logger}"
if logger != "none":
cmd += f" --logger {logger}"
cmd += " --log_samples"
if ch_8bit_adam:
cmd += " --bnb_optimizer"
@@ -515,7 +520,7 @@ def start_training(
training_process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
)
yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)
stdout_queue = queue.Queue()
stderr_queue = queue.Queue()
@@ -584,7 +589,11 @@ def start_training(
gr.update(interactive=True),
)
else:
yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
yield (
"Training complete or paused ...",
gr.update(interactive=False),
gr.update(interactive=True),
)
break
# Small sleep to prevent CPU thrashing
@@ -598,9 +607,9 @@ def start_training(
time.sleep(1)
if training_process is None:
text_info = "train stop"
text_info = "Train stopped !"
else:
text_info = "train complete !"
text_info = "Train complete at end !"
except Exception as e: # Catch all exceptions
# Ensure that we reset the training process variable in case of an error
@@ -615,11 +624,11 @@ def stop_training():
global training_process, stop_signal
if training_process is None:
return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
return "Train not running !", gr.update(interactive=True), gr.update(interactive=False)
terminate_process_tree(training_process.pid)
# training_process = None
stop_signal = True
return "train stop", gr.update(interactive=True), gr.update(interactive=False)
return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False)
def get_list_projects():
@@ -958,21 +967,23 @@ def calculate_train(
)
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
try:
checkpoint = torch.load(checkpoint_path, weights_only=True)
print("Original Checkpoint Keys:", checkpoint.keys())
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
if ema_model_state_dict is None:
return "No 'ema_model_state_dict' found in the checkpoint."
to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
try:
model_state_dict_to_retain = checkpoint[to_retain]
except KeyError:
return f"{to_retain} not found in the checkpoint."
if safetensors:
new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
save_file(ema_model_state_dict, new_checkpoint_path)
save_file(model_state_dict_to_retain, new_checkpoint_path)
else:
new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
torch.save(new_checkpoint, new_checkpoint_path)
return f"New checkpoint saved at: {new_checkpoint_path}"
@@ -1013,7 +1024,10 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
torch.save(ckpt, new_ckpt_path)
if new_ckpt_path.endswith(".safetensors"):
save_file(ema_sd, new_ckpt_path)
elif new_ckpt_path.endswith(".pt"):
torch.save(ckpt, new_ckpt_path)
return vocab_new
@@ -1125,7 +1139,7 @@ def vocab_check(project_name):
info = "You can train using your language !"
else:
vocab_miss = ",".join(miss_symbols)
info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"
return info, vocab_miss
@@ -1212,6 +1226,9 @@ def infer(
print("update >> ", device_test, file_checkpoint, use_ema)
if seed == -1: # -1 used for random
seed = None
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
ref_file=ref_audio,
@@ -1430,9 +1447,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
)
audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
txt_lang = gr.Text(label="Language", value="English")
txt_lang = gr.Textbox(label="Language", value="English")
bt_transcribe = bt_create = gr.Button("Transcribe")
txt_info_transcribe = gr.Text(label="Info", value="")
txt_info_transcribe = gr.Textbox(label="Info", value="")
bt_transcribe.click(
fn=transcribe_all,
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
@@ -1443,7 +1460,7 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
random_sample_transcribe = gr.Button("Random Sample")
with gr.Row():
random_text_transcribe = gr.Text(label="Text")
random_text_transcribe = gr.Textbox(label="Text")
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
random_sample_transcribe.click(
@@ -1458,7 +1475,7 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
```""")
check_button = gr.Button("Check Vocab")
txt_info_check = gr.Text(label="Info", value="")
txt_info_check = gr.Textbox(label="Info", value="")
gr.Markdown("""```plaintext
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
@@ -1478,7 +1495,7 @@ Using the extended model, you can finetune to a new language that is missing sym
txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
extend_button = gr.Button("Extend")
txt_info_extend = gr.Text(label="Info", value="")
txt_info_extend = gr.Textbox(label="Info", value="")
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
@@ -1518,8 +1535,8 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
bt_prepare = bt_create = gr.Button("Prepare")
txt_info_prepare = gr.Text(label="Info", value="")
txt_vocab_prepare = gr.Text(label="Vocab", value="")
txt_info_prepare = gr.Textbox(label="Info", value="")
txt_vocab_prepare = gr.Textbox(label="Vocab", value="")
bt_prepare.click(
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
@@ -1528,7 +1545,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
random_sample_prepare = gr.Button("Random Sample")
with gr.Row():
random_text_prepare = gr.Text(label="Tokenizer")
random_text_prepare = gr.Textbox(label="Tokenizer")
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
random_sample_prepare.click(
@@ -1541,50 +1558,60 @@ The auto-setting is still experimental. Set a large value of epoch if not sure;
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
with gr.Row():
bt_calculate = bt_create = gr.Button("Auto Settings")
exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
tokenizer_file = gr.Textbox(label="Tokenizer File")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune")
lb_samples = gr.Label(label="Samples")
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
bt_calculate = bt_create = gr.Button("Auto Settings")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
epochs = gr.Number(label="Epochs")
learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
max_grad_norm = gr.Number(label="Max Gradient Norm")
num_warmup_updates = gr.Number(label="Warmup Updates")
with gr.Row():
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
batch_size_type = gr.Radio(
label="Batch Size Type",
choices=["frame", "sample"],
info="frame is calculated as seconds * sampling_rate / hop_length",
)
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
grad_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
)
max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")
with gr.Row():
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=3200)
max_samples = gr.Number(label="Max Samples", value=64)
with gr.Row():
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
with gr.Row():
epochs = gr.Number(label="Epochs", value=100)
num_warmup_updates = gr.Number(label="Warmup Updates", value=100)
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=500)
save_per_updates = gr.Number(
label="Save per Updates",
info="Save intermediate checkpoints every N updates",
minimum=10,
)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
step=1,
precision=0,
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
minimum=-1,
)
last_per_updates = gr.Number(label="Last per Updates", value=100)
last_per_updates = gr.Number(
label="Last per Updates",
info="Save latest checkpoint with suffix _last.pt every N updates",
minimum=10,
)
gr.Radio(label="") # placeholder
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="fp16")
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
with gr.Column():
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
if projects_selelect is not None:
(
@@ -1631,7 +1658,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
ch_8bit_adam.value = bnb_optimizer_value
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
txt_info_train = gr.Text(label="Info", value="")
txt_info_train = gr.Textbox(label="Info", value="")
list_audios, select_audio = get_audio_project(projects_selelect, False)
@@ -1760,7 +1787,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
with gr.TabItem("Test Model"):
gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
```""")
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
@@ -1770,11 +1797,13 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
with gr.Row():
nfe_step = gr.Number(label="NFE Step", value=32)
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
seed = gr.Number(label="Seed", value=-1, minimum=-1)
seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
remove_silence = gr.Checkbox(label="Remove Silence")
ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
with gr.Row():
ch_use_ema = gr.Checkbox(
label="Use EMA", value=True, info="Turn off at early stage might offer better results"
)
cm_checkpoint = gr.Dropdown(
choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
)
@@ -1782,20 +1811,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
random_sample_infer = gr.Button("Random Sample")
ref_text = gr.Textbox(label="Ref Text")
ref_audio = gr.Audio(label="Audio Ref", type="filepath")
gen_text = gr.Textbox(label="Gen Text")
ref_text = gr.Textbox(label="Reference Text")
ref_audio = gr.Audio(label="Reference Audio", type="filepath")
gen_text = gr.Textbox(label="Text to Generate")
random_sample_infer.click(
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
)
with gr.Row():
txt_info_gpu = gr.Textbox("", label="Device")
seed_info = gr.Text(label="Seed :")
check_button_infer = gr.Button("Infer")
txt_info_gpu = gr.Textbox("", label="Inference on Device :")
seed_info = gr.Textbox(label="Used Random Seed :")
check_button_infer = gr.Button("Inference")
gen_audio = gr.Audio(label="Audio Gen", type="filepath")
gen_audio = gr.Audio(label="Generated Audio", type="filepath")
check_button_infer.click(
fn=infer,
@@ -1822,14 +1851,16 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
gr.Markdown("""```plaintext
Reduce the Base model size from 5GB to 1.3GB. The pruned checkpoint strips out the optimizer state and other training-only data; it can still be used for inference or finetuning afterward, but it cannot resume pretraining.
```""")
txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Text(label="Path to Output:")
ch_safetensors = gr.Checkbox(label="Safetensors", value="")
txt_info_reduse = gr.Text(label="Info", value="")
reduse_button = gr.Button("Reduce")
txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
with gr.Row():
ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
txt_info_reduse = gr.Textbox(label="Info", value="")
reduse_button = gr.Button("Prune")
reduse_button.click(
fn=extract_and_save_ema_model,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
fn=prune_checkpoint,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
outputs=[txt_info_reduse],
)
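The body of the renamed `prune_checkpoint` function wired up above is not part of this hunk. A minimal sketch of what such a pruning step could look like, assuming typical checkpoint keys (`ema_model_state_dict`, `model_state_dict`, `optimizer_state_dict`); the repository's actual implementation may differ:

```python
# Hedged sketch only: key names and output layout are assumptions, not the repo's exact code.
import torch


def prune_checkpoint_sketch(ckpt_path: str, out_path: str, save_ema: bool = True, use_safetensors: bool = True):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Keep only the model weights (the EMA variant if requested and present),
    # dropping optimizer state and other training-only bookkeeping.
    key = "ema_model_state_dict" if save_ema and "ema_model_state_dict" in ckpt else "model_state_dict"
    state_dict = ckpt[key]
    if use_safetensors:
        from safetensors.torch import save_file  # pip install safetensors

        save_file(state_dict, out_path)  # out_path should end in .safetensors
    else:
        torch.save({"model_state_dict": state_dict}, out_path)
    # Either way the result is inference/finetune-only: the optimizer state needed to resume pretraining is gone.
```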


@@ -6,68 +6,69 @@ from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, DiT, UNetT, Trainer # noqa: F401. used for config
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
model_cls = globals()[cfg.model.backbone]
model_arc = cfg.model.arch
tokenizer = cfg.model.tokenizer
mel_spec_type = cfg.model.mel_spec.mel_spec_type
def main(model_cfg):
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None
# set text tokenizer
if tokenizer != "custom":
tokenizer_path = cfg.datasets.name
tokenizer_path = model_cfg.datasets.name
else:
tokenizer_path = cfg.model.tokenizer_path
tokenizer_path = model_cfg.model.tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
# set model
model = CFM(
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=cfg.model.mel_spec,
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=model_cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)
# init trainer
trainer = Trainer(
model,
epochs=cfg.optim.epochs,
learning_rate=cfg.optim.learning_rate,
num_warmup_updates=cfg.optim.num_warmup_updates,
save_per_updates=cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
batch_size_per_gpu=cfg.datasets.batch_size_per_gpu,
batch_size_type=cfg.datasets.batch_size_type,
max_samples=cfg.datasets.max_samples,
grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
max_grad_norm=cfg.optim.max_grad_norm,
logger=cfg.ckpts.logger,
epochs=model_cfg.optim.epochs,
learning_rate=model_cfg.optim.learning_rate,
num_warmup_updates=model_cfg.optim.num_warmup_updates,
save_per_updates=model_cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
batch_size_type=model_cfg.datasets.batch_size_type,
max_samples=model_cfg.datasets.max_samples,
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
max_grad_norm=model_cfg.optim.max_grad_norm,
logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=cfg.ckpts.last_per_updates,
log_samples=cfg.ckpts.log_samples,
bnb_optimizer=cfg.optim.bnb_optimizer,
last_per_updates=model_cfg.ckpts.last_per_updates,
log_samples=model_cfg.ckpts.log_samples,
bnb_optimizer=model_cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
is_local_vocoder=cfg.model.vocoder.is_local,
local_vocoder_path=cfg.model.vocoder.local_path,
cfg_dict=OmegaConf.to_container(cfg, resolve=True),
is_local_vocoder=model_cfg.model.vocoder.is_local,
local_vocoder_path=model_cfg.model.vocoder.local_path,
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
)
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
trainer.train(
train_dataset,
num_workers=cfg.datasets.num_workers,
num_workers=model_cfg.datasets.num_workers,
resumable_with_seed=666, # seed for shuffling dataset
)
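A side note on the backbone change above: resolving the class via `hydra.utils.get_class` replaces the `globals()` lookup and lets the `# noqa: F401` import be dropped. A minimal sketch, assuming `"DiT"` as an example value for `model_cfg.model.backbone` (any class exported by `f5_tts.model` works the same way):

```python
# Sketch of the dynamic backbone lookup used by the new train.py.
# "DiT" is an assumed example; in train.py the name comes from the Hydra config.
from hydra.utils import get_class

backbone_name = "DiT"
model_cls = get_class(f"f5_tts.model.{backbone_name}")  # import-path lookup instead of globals()
print(model_cls)  # the DiT class object, ready to be instantiated with the arch kwargs from the config
```

Because the `@hydra.main` decorator sets `config_name=None`, the config must be chosen at launch time with `--config-name` (a YAML under `f5_tts/configs`), and individual fields can still be overridden on the command line in the usual Hydra way.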