89 Commits
0.6.2 ... 1.1.4

Author SHA1 Message Date
SWivid
ac79d0ec1e v1.1.4 2025-05-05 04:05:25 +08:00
SWivid
dad398c0c1 Bug Fix #1015
Ensure custom config hashable in
2025-05-05 03:55:05 +08:00
SWivid
3d969bf78d minor fix for backward compatibility to gradio multistyle feature 2025-05-05 02:07:19 +08:00
SWivid
7c741c05f9 v1.1.3 better infer_gradio with cherrypick and cache support 2025-05-05 01:42:41 +08:00
SWivid
6d1a1e886a formatting, sorting 2025-05-05 01:41:28 +08:00
SWivid
b4efcd836a Add cache feature. Retrieve previous generated segments, default cache size 100 2025-05-05 01:37:22 +08:00
SWivid
818b868fab Update infer_gradio.py. Enable seed selecting for multistyle generation 2025-05-05 00:58:24 +08:00
SWivid
e6fee5e9ba Update infer_gradio.py
Use gr.Column to ensure backward compatibility

Remove height attr from gr.File to avoid possible malposition across versions
2025-05-04 09:25:41 +08:00
Yushen CHEN
2de214c122 Merge pull request #1014 from fakerybakery/fix-gradio-app-250503
Fix Gradio app
2025-05-04 09:14:32 +08:00
mrfakename
2999f642ce Row -> Column 2025-05-03 17:59:07 -07:00
mrfakename
03cff73343 remove equal_height requirement
Seems to break Gradio demo.
2025-05-03 17:57:41 -07:00
mrfakename
63c513840d fix gradio app 2025-05-03 17:56:21 -07:00
SWivid
3e6b6c0c0c update infer_gradio.py. rename for consistency 2025-05-04 08:04:00 +08:00
SWivid
f00ac4d06b fix infer-gradio chat feature etc. 2025-05-04 08:00:16 +08:00
Yushen CHEN
b0658bfd24 Merge pull request #1013 from petermg/main
Update infer_gradio.py
2025-05-04 03:33:22 +08:00
petermg
0cae51d646 Update infer_gradio.py
Modified formatting
2025-05-03 12:07:58 -07:00
petermg
95976041f2 Update infer_gradio.py
Added "randomize seed" checkmark and option to specify seed showing last seed used and can manually enter the desired seed number.
2025-05-03 11:38:50 -07:00
petermg
ba1bf74215 Update infer_gradio.py
Modified it so that when you upload a text file, the text of that file will show in the text input window. Also made the text file upload window show up BELOW the text input display window.
2025-05-03 11:22:07 -07:00
petermg
536c29ac57 Update infer_gradio.py
Modified the UI to accept txt files as inputs
2025-05-02 12:45:39 -07:00
SWivid
c4c61b0110 v1.1.2 several updates
add data prepare script recipe for emilia-yodas; fix speech_edit.py; fix tensorrt-llm server code-switch
2025-05-02 03:13:33 +08:00
SWivid
5f80fec160 fix speech_edit.py 2025-04-26 02:10:39 +08:00
Yushen CHEN
178cb8afe6 Merge pull request #986 from fakerybakery/emilia-v2
Add processing script for new Emilia dataset format
2025-04-19 14:16:37 +08:00
mrfakename
761c7ed938 Add processing script for new Emilia dataset format 2025-04-18 20:56:31 -07:00
Yushen CHEN
13fd6f8e07 Merge pull request #971 from tbxark-fork/main
chore: Update the model checkpoint path to use the cache path.
2025-04-14 15:54:50 +08:00
tbxark
b2284b6cff chore: Update the model checkpoint path to use the cache path. 2025-04-14 11:28:48 +08:00
SWivid
4b4359bc39 finetune_gradio not to use fp16 by default for mps device 2025-04-03 22:33:21 +08:00
SWivid
fe5c562212 v1.1.1 add benchmark and trtllm offline code 2025-04-03 18:33:48 +08:00
Yushen CHEN
2374f8ec39 Merge pull request #948 from yuekaizhang/trtllm_benchmark
[TRT-LLM] add benchmark code
2025-04-03 18:27:21 +08:00
Yuekai Zhang
f4f10bff6c fix comment 2025-04-03 02:44:59 -07:00
Yuekai Zhang
9771ec6a3a add benchmark code 2025-04-03 02:42:40 -07:00
SWivid
4b3cd13382 Update README.md 2025-04-03 15:04:42 +08:00
SWivid
25b3291715 Update README.md 2025-04-03 14:41:52 +08:00
SWivid
16c480a61d v1.1.0 Support GPU Depolyment with Triton and TensorRT-LLM #944 2025-04-03 14:37:58 +08:00
SWivid
d9dfbe47cc Update README.md 2025-04-03 14:36:22 +08:00
Yushen CHEN
d1f6c95fe8 Merge pull request #944 from yuekaizhang/triton
Support GPU Depolyment Solution with Triton and TensorRT-LLM
2025-04-03 13:42:37 +08:00
root
2428d01a56 remove empty lines 2025-04-03 05:25:29 +00:00
root
9401842930 add http client 2025-04-03 05:14:03 +00:00
root
eca56943ec fix docker compose issue 2025-04-03 04:31:33 +00:00
root
ae51cc3d34 fix bug 2025-04-03 04:25:43 +00:00
root
4681a1c177 remove annotation 2025-04-03 02:35:26 +00:00
root
5b178397e0 remove unused codes 2025-04-03 02:34:28 +00:00
Yuekai Zhang
2724f9f101 add Nvidia Triton TensorRT-LLM solution 2025-04-02 19:04:45 -07:00
SWivid
7258b09529 v1.0.10 support custom chat model 2025-03-31 21:15:26 +08:00
SWivid
784e3862b4 add microsoft/Phi-4-mini-instruct to chat model list #937 2025-03-31 21:14:39 +08:00
SWivid
6f6968b034 formatting 2025-03-31 19:45:38 +08:00
maximechen
9bd2d13be1 Merge branch 'huanglizhuo-feat/support-custom-chat-model' 2025-03-31 19:22:08 +08:00
maximechen
b7c41af9cd reorganize and distinguish behavior from local and space 2025-03-31 19:11:52 +08:00
huanglizhuo
eaa7fd8a01 Reapply pre-commit hooks 2025-03-29 20:58:42 +09:00
Yushen CHEN
f34465d118 v1.0.9 several fixes 2025-03-28 23:12:13 +08:00
lizhuo
393993321d fix: use pydantic<=2.10.6 to address dependency conflict with gradio-app #930 2025-03-28 23:10:41 +08:00
lizhuo
29d3326bed update: JA latest HF path in SHARED.md #928
* fix: update japanese latest hf path
* update the huggingface url
2025-03-28 22:36:17 +08:00
Zhikang Niu
67e43dc0fb Merge pull request #926 from huanglizhuo/fix/shared-file-path
fix the SHARED.md file path
2025-03-28 17:14:54 +08:00
huanglizhuo
8469025b1c fix the shared.md file path 2025-03-28 17:52:08 +09:00
Zhikang Niu
5bd8cd7aed update: better save last & per ckpt logic #924
Co-authored-by: Yushen CHEN <45333109+SWivid@users.noreply.github.com>
2025-03-28 13:53:12 +08:00
SWivid
7236536f9a update utils_infer.py 2025-03-25 17:24:20 +08:00
SWivid
6b7f6eefdc fix typo in trainer.py with 4ae5347282 formatting #909 2025-03-25 16:17:03 +08:00
SWivid
b9156c0ad5 v1.0.8 fix a fatal bug with log_samples since 37eb3b50da 2025-03-25 07:49:19 +08:00
SWivid
3ad3211915 Update F5TTS_Small.yaml 2025-03-25 07:11:35 +08:00
Zhikang Niu
f6726a78cc Update F5TTS_Small.yaml 2025-03-23 22:27:02 +08:00
SWivid
1d0cf2b8ba add device option for infer-cli, patch-1 2025-03-22 17:35:16 +08:00
SWivid
1d82b7928e add device option for infer-cli 2025-03-22 17:30:23 +08:00
SWivid
4ae5347282 pre-commit update and formatting 2025-03-21 23:01:00 +08:00
SWivid
621559cbbe v1.0.7 2025-03-21 14:40:52 +08:00
SWivid
526b09eebd add no_zero_init v1 variant path to SHARED.md 2025-03-21 14:37:14 +08:00
SWivid
9afa80f204 add option in finetune gradio to save non-ema model weight 2025-03-21 13:36:11 +08:00
SWivid
c6b3189bbd v1.0.6 improves docker usage 2025-03-20 22:48:36 +08:00
Yushen CHEN
c87ce39515 Merge pull request #890 from MicahZoltu/patch-1
Improves documentation around docker usage.
2025-03-20 22:45:40 +08:00
Micah Zoltu
10ef27065b Improves documentation around docker usage. 2025-03-20 21:37:48 +08:00
SWivid
f374640f34 Merge branch 'main' of github.com:SWivid/F5-TTS 2025-03-20 13:54:52 +08:00
SWivid
d5f4c88aa4 update issue templates 2025-03-20 13:54:15 +08:00
Yushen CHEN
f968e13b6d Update README.md 2025-03-20 10:15:47 +08:00
SWivid
339b17fed3 update README.md for infer & train 2025-03-20 10:14:22 +08:00
SWivid
79302b694a update README.md for infer & train 2025-03-20 10:03:54 +08:00
SWivid
a1e88c2a9e v1.0.5 update finetune_gradio.py for clearer guidance 2025-03-17 21:50:50 +08:00
SWivid
1ab90505a4 v1.0.4 fix finetune_gradio.py vocab extend with .safetensors ckpt 2025-03-17 16:22:26 +08:00
SWivid
7e4985ca56 v1.0.3 fix api.py 2025-03-17 02:39:20 +08:00
SWivid
f05ceda4cb v1.0.2 fix: torch.utils.checkpoint.checkpoint add use_reentrant=False 2025-03-15 16:34:32 +08:00
Yushen CHEN
2bd39dd813 Merge pull request #859 from ZhikangNiu/main
fix #858 and pass use_reentrant explicitly in checkpoint_activation mode
2025-03-15 16:23:50 +08:00
ZhikangNiu
f017815083 fix #858 and pass use_reentrant explicitly in checkpoint_activation mode 2025-03-15 15:48:47 +08:00
Yushen CHEN
297755fac3 v1.0.1 VRAM usage management #851 2025-03-14 17:31:44 +08:00
Yushen CHEN
d05075205f Merge pull request #851 from niknah/vram-usage
VRAM usage on long texts gradually uses up memory.
2025-03-14 17:25:56 +08:00
Yushen CHEN
8722cf0766 Update utils_infer.py 2025-03-14 17:23:20 +08:00
niknah
48d1a9312e VRAM usage on long texts gradually uses up memory. 2025-03-14 16:53:58 +11:00
Yushen CHEN
128f4e4bf3 Update publish-pypi.yaml 2025-03-13 00:08:36 +08:00
SWivid
2695e9305d v1.0.0 release 2025-03-12 23:47:04 +08:00
SWivid
69909ac167 update README.md 2025-03-12 18:40:07 +08:00
SWivid
79bbde5d76 update README.md add a glance of few demo 2025-03-12 18:37:14 +08:00
SWivid
bf651d541e update README.md for v1.0.0 2025-03-12 17:39:30 +08:00
SWivid
ca6e49adaa 1.0.0 F5-TTS v1 base model with better training and inference performance 2025-03-12 17:23:10 +08:00
78 changed files with 5744 additions and 1125 deletions

View File

@@ -1,6 +1,6 @@
name: "Bug Report"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- bug
body:
@@ -15,13 +15,13 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., CentOS Linux 7, RTX 3090, Python 3.10, torch==2.3.0, cuda 11.8
description: "Provide details including OS, GPU info, Python version, any relevant software or dependencies, and trainer setting."
placeholder: e.g., CentOS Linux 7, 4 * RTX 3090, Python 3.10, torch==2.3.0+cu118, cuda 11.8, config yaml is ...
validations:
required: true
- type: textarea
@@ -39,12 +39,12 @@ body:
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen.
placeholder: Describe in detail what you expected to happen.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened.
placeholder: Describe in detail what actually happened.
validations:
required: false

View File

@@ -15,7 +15,7 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and found not discussion yet.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:

View File

@@ -1,6 +1,6 @@
name: "Help Wanted"
description: |
Please provide as much details to help address the issue, including logs and screenshots.
Please provide as much details to help address the issue more efficiently, including input, output, logs and screenshots.
labels:
- help wanted
body:
@@ -15,36 +15,40 @@ body:
required: true
- label: I have searched for existing issues, including closed ones, and couldn't find a solution.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:
label: Environment Details
description: "Provide details such as OS, Python version, and any relevant software or dependencies."
placeholder: e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
placeholder: |
e.g., macOS 13.5, Python 3.10, torch==2.3.0, Gradio 4.44.1
If training or finetuning related, provide detailed configuration including GPU info and training setup.
validations:
required: true
- type: textarea
attributes:
label: Steps to Reproduce
description: |
Include detailed steps, screenshots, and logs. Use the correct markdown syntax for code blocks.
Include detailed steps, screenshots, and logs. Provide used prompt wav and text. Use the correct markdown syntax for code blocks.
placeholder: |
1. Create a new conda environment.
2. Clone the repository and install as pip package.
3. Run the command: `f5-tts_infer-gradio` with no ref_text provided.
4. Stuck there with the following message... (attach logs and also error msg e.g. after ctrl-c).
5. Prompt & generated wavs are [change suffix to .mp4 to enable direct upload or pack all to .zip].
6. Reference audio's transcription or provided ref_text is `xxx`, and text to generate is `xxx`.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
placeholder: Describe what you expected to happen, e.g. output a generated audio
placeholder: Describe what you expected to happen in detail, e.g. output a generated audio.
validations:
required: false
- type: textarea
attributes:
label: ❌ Actual Behavior
placeholder: Describe what actually happened, failure messages, etc.
placeholder: Describe what actually happened in detail, failure messages, etc.
validations:
required: false

View File

@@ -1,6 +1,6 @@
name: "Question"
description: |
Pure question or inquiry about the project, usage issue goes with "help wanted".
Research question or pure inquiry about the project, usage issue goes with "help wanted".
labels:
- question
body:
@@ -9,13 +9,13 @@ body:
label: Checks
description: "To help us grasp quickly, please confirm the following:"
options:
- label: This template is only for question, not feature requests or bug reports.
- label: This template is only for research question, not usage problems, feature requests or bug reports.
required: true
- label: I have thoroughly reviewed the project documentation and read the related paper(s).
required: true
- label: I have searched for existing issues, including closed ones, no similar questions.
required: true
- label: I confirm that I am using English to submit this report in order to facilitate communication.
- label: I am using English to submit this issue to facilitate community communication.
required: true
- type: textarea
attributes:

66
.github/workflows/publish-pypi.yaml vendored Normal file
View File

@@ -0,0 +1,66 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
release-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build release distributions
run: |
# NOTE: put your own distribution build steps here.
python -m pip install build
python -m build
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
pypi-publish:
runs-on: ubuntu-latest
needs:
- release-build
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
# Dedicated environments with protections for publishing are strongly recommended.
environment:
name: pypi
# OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
# url: https://pypi.org/p/YOURPROJECT
steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
with:
name: release-dists
path: dist/
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

2
.gitignore vendored
View File

@@ -7,8 +7,6 @@ ckpts/
wandb/
results/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -1,14 +1,17 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.0
rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
name: ruff linter
args: [--fix]
# Run the formatter.
- id: ruff-format
name: ruff formatter
- id: ruff
name: ruff sorter
args: [--select, I, --fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: v5.0.0
hooks:
- id: check-yaml

View File

@@ -23,4 +23,8 @@ RUN git clone https://github.com/SWivid/F5-TTS.git \
ENV SHELL=/bin/bash
VOLUME /root/.cache/huggingface/hub/
EXPOSE 7860
WORKDIR /workspace/F5-TTS

View File

@@ -18,6 +18,7 @@
### Thanks to all the contributors !
## News
- **2025/03/12**: 🔥 F5-TTS v1 base model with better training and inference performance. [Few demo](https://swivid.github.io/F5-TTS_updates).
- **2024/10/08**: F5-TTS & E2 TTS base models on [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), [🟣 Wisemodel](https://wisemodel.cn/models/SJTU_X-LANCE/F5-TTS_Emilia-ZH-EN).
## Installation
@@ -37,7 +38,7 @@ conda activate f5-tts
> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.3.0+cu118 torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> ```
</details>
@@ -82,7 +83,7 @@ conda activate f5-tts
> ### 1. As a pip package (if just for inference)
>
> ```bash
> pip install git+https://github.com/SWivid/F5-TTS.git
> pip install f5-tts
> ```
>
> ### 2. Local editable (if also do training, finetuning)
@@ -99,13 +100,34 @@ conda activate f5-tts
# Build from Dockerfile
docker build -t f5tts:v1 .
# Or pull from GitHub Container Registry
docker pull ghcr.io/swivid/f5-tts:main
# Run from GitHub Container Registry
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main
# Quickstart if you want to just run the web interface (not CLI)
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
### Runtime
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
## Inference
- In order to achieve desired performance, take a moment to read [detailed guidance](src/f5_tts/infer).
- By properly searching the keywords of problem encountered, [issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very helpful.
### 1. Gradio App
Currently supported features:
@@ -158,9 +180,8 @@ volumes:
```bash
# Run with flags
# Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--ref_audio "ref_audio.wav" \
f5-tts_infer-cli --model F5TTS_v1_Base \
--ref_audio "provide_prompt_wav_path_here.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
@@ -173,30 +194,29 @@ f5-tts_infer-cli -c custom.toml
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
### 3. More instructions
- In order to have better generation results, take a moment to read [detailed guidance](src/f5_tts/infer).
- The [Issues](https://github.com/SWivid/F5-TTS/issues?q=is%3Aissue) are very useful, please try to find the solution by properly searching the keywords of problem encountered. If no answer found, then feel free to open an issue.
## Training
### 1. Gradio App
### 1. With Hugging Face Accelerate
Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
Refer to [training & finetuning guidance](src/f5_tts/train) for best practice.
### 2. With Gradio App
```bash
# Quick start with Gradio web interface
f5-tts_finetune-gradio
```
Read [training & finetuning guidance](src/f5_tts/train) for more instructions.
## [Evaluation](src/f5_tts/eval)
## Development
Use pre-commit to ensure code quality (will run linters and formatters automatically)
Use pre-commit to ensure code quality (will run linters and formatters automatically):
```bash
pip install pre-commit
@@ -209,7 +229,7 @@ When making a pull request, before each commit, run:
pre-commit run --all-files
```
Note: Some model components have linting exceptions for E722 to accommodate tensor notation
Note: Some model components have linting exceptions for E722 to accommodate tensor notation.
## Acknowledgements
@@ -224,6 +244,7 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
- [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
- [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
- [F5-TTS-ONNX](https://github.com/DakeQQ/F5-TTS-ONNX) ONNX Runtime version by [DakeQQ](https://github.com/DakeQQ)
- [Yuekai Zhang](https://github.com/yuekaizhang) Triton and TensorRT-LLM support ~
## Citation
If our work and codebase is useful for you, please cite as:

View File

@@ -1,10 +1,3 @@
The pretrained model checkpoints can be reached at https://huggingface.co/SWivid/F5-TTS.
Pretrained model ckpts. https://huggingface.co/SWivid/F5-TTS
```
ckpts/
E2TTS_Base/
model_1200000.pt
F5TTS_Base/
model_1200000.pt
```
Scripts will automatically pull model checkpoints from Huggingface, by default to `~/.cache/huggingface/hub/`.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "0.6.2"
version = "1.1.4"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -25,8 +25,8 @@ dependencies = [
"jieba",
"librosa",
"matplotlib",
"nltk",
"numpy<=1.26.4",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
"safetensors",

View File

@@ -6,5 +6,5 @@ target-version = "py310"
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = true
force-single-line = false
lines-after-imports = 2

View File

@@ -5,9 +5,10 @@ from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
hop_length,
infer_process,
load_model,
load_vocoder,
@@ -15,33 +16,32 @@ from f5_tts.infer.utils_infer import (
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
target_sample_rate,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything
class F5TTS:
def __init__(
self,
model_type="F5-TTS",
model="F5TTS_v1_Base",
ckpt_file="",
vocab_file="",
ode_method="euler",
use_ema=True,
vocoder_name="vocos",
local_path=None,
vocoder_local_path=None,
device=None,
hf_cache_dir=None,
):
# Initialize parameters
self.final_wave = None
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.seed = -1
self.mel_spec_type = vocoder_name
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
self.ode_method = ode_method
self.use_ema = use_ema
# Set device
if device is not None:
self.device = device
else:
@@ -58,39 +58,29 @@ class F5TTS:
)
# Load models
self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
self.load_ema_model(
model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
self.vocoder = load_vocoder(
self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
)
def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
if model_type == "F5-TTS":
if not ckpt_file:
if mel_spec_type == "vocos":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
elif mel_spec_type == "bigvgan":
ckpt_file = str(
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
elif model_type == "E2-TTS":
if not ckpt_file:
ckpt_file = str(
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
)
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cls = UNetT
else:
raise ValueError(f"Unknown model type: {model_type}")
# override for previous models
if model == "F5TTS_Base":
if self.mel_spec_type == "vocos":
ckpt_step = 1200000
elif self.mel_spec_type == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
if not ckpt_file:
ckpt_file = str(
cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
)
self.ema_model = load_model(
model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
)
def transcribe(self, ref_audio, language=None):
@@ -102,8 +92,8 @@ class F5TTS:
if remove_silence:
remove_silence_for_generated_wav(file_wave)
def export_spectrogram(self, spect, file_spect):
save_spectrogram(spect, file_spect)
def export_spectrogram(self, spec, file_spec):
save_spectrogram(spec, file_spec)
def infer(
self,
@@ -121,17 +111,17 @@ class F5TTS:
fix_duration=None,
remove_silence=False,
file_wave=None,
file_spect=None,
seed=-1,
file_spec=None,
seed=None,
):
if seed == -1:
if seed is None:
seed = random.randint(0, sys.maxsize)
seed_everything(seed)
self.seed = seed
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)
wav, sr, spect = infer_process(
wav, sr, spec = infer_process(
ref_file,
ref_text,
gen_text,
@@ -153,22 +143,22 @@ class F5TTS:
if file_wave is not None:
self.export_wav(wav, file_wave, remove_silence)
if file_spect is not None:
self.export_spectrogram(spect, file_spect)
if file_spec is not None:
self.export_spectrogram(spec, file_spec)
return wav, sr, spect
return wav, sr, spec
if __name__ == "__main__":
f5tts = F5TTS()
wav, sr, spect = f5tts.infer(
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=-1, # random seed = -1
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
print("seed :", f5tts.seed)

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
model:
name: E2TTS_Base
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 1024
depth: 24
heads: 16
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,25 +20,29 @@ optim:
model:
name: E2TTS_Small
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: UNetT
arch:
dim: 768
depth: 20
heads: 12
ff_mult: 4
text_mask_padding: False
pe_attn_head: 1
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
model:
name: F5TTS_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
@@ -35,13 +38,14 @@ model:
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates

View File

@@ -1,16 +1,16 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # "frame" or "sample"
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 15
epochs: 11 # only suitable for Emilia, if you want to train it on LibriTTS, set epoch 686
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
@@ -20,14 +20,17 @@ optim:
model:
name: F5TTS_Small
tokenizer: pinyin
tokenizer_path: None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 768
depth: 18
heads: 12
ff_mult: 2
text_dim: 512
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
@@ -35,14 +38,15 @@ model:
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # 'vocos' or 'bigvgan'
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: None # local vocoder path
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | None
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -0,0 +1,53 @@
hydra:
run:
dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
datasets:
name: Emilia_ZH_EN # dataset name
batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
batch_size_type: frame # frame | sample
max_samples: 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
num_workers: 16
optim:
epochs: 11
learning_rate: 7.5e-5
num_warmup_updates: 20000 # warmup updates
grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
max_grad_norm: 1.0 # gradient clipping
bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
model:
name: F5TTS_v1_Base # model name
tokenizer: pinyin # tokenizer type
tokenizer_path: null # if 'custom' tokenizer, define the path want to use (should be vocab.txt)
backbone: DiT
arch:
dim: 1024
depth: 22
heads: 16
ff_mult: 2
text_dim: 512
text_mask_padding: True
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000
n_mel_channels: 100
hop_length: 256
win_length: 1024
n_fft: 1024
mel_spec_type: vocos # vocos | bigvgan
vocoder:
is_local: False # use local offline ckpt or not
local_path: null # local vocoder path
ckpts:
logger: wandb # wandb | tensorboard | null
log_samples: True # infer random sample per save checkpoint. wip, normal to fail with extra long samples
save_per_updates: 50000 # save checkpoint per updates
keep_last_n_checkpoints: -1 # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
last_per_updates: 5000 # save last checkpoint per updates
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}

View File

@@ -4,6 +4,7 @@
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@@ -1,6 +1,7 @@
import os
import sys
sys.path.append(os.getcwd())
import argparse
@@ -10,6 +11,8 @@ from importlib.resources import files
import torch
import torchaudio
from accelerate import Accelerator
from hydra.utils import get_class
from omegaconf import OmegaConf
from tqdm import tqdm
from f5_tts.eval.utils_eval import (
@@ -18,36 +21,27 @@ from f5_tts.eval.utils_eval import (
get_seedtts_testset_metainfo,
)
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
use_ema = True
target_rms = 0.1
rel_path = str(files("f5_tts").joinpath("../../"))
def main():
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument("-s", "--seed", default=None, type=int)
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
parser.add_argument("-n", "--expname", required=True)
parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
parser.add_argument("-c", "--ckptstep", default=1250000, type=int)
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
parser.add_argument("-o", "--odemethod", default="euler")
@@ -58,12 +52,8 @@ def main():
args = parser.parse_args()
seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
mel_spec_type = args.mel_spec_type
tokenizer = args.tokenizer
nfe_step = args.nfestep
ode_method = args.odemethod
@@ -77,13 +67,19 @@ def main():
use_truth_duration = False
no_ref_audio = False
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
if testset == "ls_pc_test_clean":
metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
@@ -111,8 +107,6 @@ def main():
# -------------------------------------------------#
use_ema = True
prompts_all = get_inference_prompt(
metainfo,
speed=speed,
@@ -139,7 +133,7 @@ def main():
# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
@@ -154,6 +148,10 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
@@ -200,7 +198,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
print(f"Done batch inference in {timediff / 60:.2f} minutes.")
if __name__ == "__main__":

View File

@@ -1,13 +1,18 @@
#!/bin/bash
# e.g. F5-TTS, 16 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_zh" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "seedtts_test_en" -nfe 16
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_v1_Base" -t "ls_pc_test_clean" -nfe 16
# e.g. Vanilla E2 TTS, 32 NFE
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_zh" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "seedtts_test_en" -o "midpoint" -ss 0
accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -c 1200000 -t "ls_pc_test_clean" -o "midpoint" -ss 0
# e.g. evaluate F5-TTS 16 NFE result on Seed-TTS test-zh
python src/f5_tts/eval/eval_seedtts_testset.py -e wer -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_seedtts_testset.py -e sim -l zh --gen_wav_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0 --gpu_nums 8
python src/f5_tts/eval/eval_utmos.py --audio_dir results/F5TTS_v1_Base_1250000/seedtts_test_zh/seed0_euler_nfe32_vocos_ss-1_cfg2.0_speed1.0
# etc.

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_librispeech_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))
@@ -53,43 +52,37 @@ def main():
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
# --------------------------------------------------------------------------
full_results = []
metrics = []
if eval_task == "wer":
wer_results = []
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
wer_results.extend(r)
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sims = []
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
sims.extend(r)
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")
sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")
print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_seed_tts_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))
@@ -52,43 +51,37 @@ def main():
asr_ckpt_dir = "" # auto download to cache dir
wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
# --------------------------- WER ---------------------------
# --------------------------------------------------------------------------
full_results = []
metrics = []
if eval_task == "wer":
wer_results = []
wers = []
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_asr_wer, args)
for r in results:
wer_results.extend(r)
wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
with open(wer_result_path, "w") as f:
for line in wer_results:
wers.append(line["wer"])
json_line = json.dumps(line, ensure_ascii=False)
f.write(json_line + "\n")
wer = round(np.mean(wers) * 100, 3)
print(f"\nTotal {len(wers)} samples")
print(f"WER : {wer}%")
print(f"Results have been saved to {wer_result_path}")
# --------------------------- SIM ---------------------------
if eval_task == "sim":
sims = []
full_results.extend(r)
elif eval_task == "sim":
with mp.Pool(processes=len(gpus)) as pool:
args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
results = pool.map(run_sim, args)
for r in results:
sims.extend(r)
full_results.extend(r)
else:
raise ValueError(f"Unknown metric type: {eval_task}")
sim = round(sum(sims) / len(sims), 3)
print(f"\nTotal {len(sims)} samples")
print(f"SIM : {sim}")
result_path = f"{gen_wav_dir}/_{eval_task}_results.jsonl"
with open(result_path, "w") as f:
for line in full_results:
metrics.append(line[eval_task])
f.write(json.dumps(line, ensure_ascii=False) + "\n")
metric = round(np.mean(metrics), 5)
f.write(f"\n{eval_task.upper()}: {metric}\n")
print(f"\nTotal {len(metrics)} samples")
print(f"{eval_task.upper()}: {metric}")
print(f"{eval_task.upper()} results saved to {result_path}")
if __name__ == "__main__":

View File

@@ -19,25 +19,23 @@ def main():
predictor = predictor.to(device)
audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
utmos_results = {}
utmos_score = 0
for audio_path in tqdm(audio_paths, desc="Processing"):
wav_name = audio_path.stem
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
utmos_results[str(wav_name)] = score.item()
utmos_score += score.item()
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
print(f"UTMOS: {avg_score}")
utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
utmos_result_path = Path(args.audio_dir) / "_utmos_results.jsonl"
with open(utmos_result_path, "w", encoding="utf-8") as f:
json.dump(utmos_results, f, ensure_ascii=False, indent=4)
for audio_path in tqdm(audio_paths, desc="Processing"):
wav, sr = librosa.load(audio_path, sr=None, mono=True)
wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
score = predictor(wav_tensor, sr)
line = {}
line["wav"], line["utmos"] = str(audio_path.stem), score.item()
utmos_score += score.item()
f.write(json.dumps(line, ensure_ascii=False) + "\n")
avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
f.write(f"\nUTMOS: {avg_score:.4f}\n")
print(f"Results have been saved to {utmos_result_path}")
print(f"UTMOS: {avg_score:.4f}")
print(f"UTMOS results saved to {utmos_result_path}")
if __name__ == "__main__":

View File

@@ -148,9 +148,9 @@ def get_inference_prompt(
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
assert min_tokens <= total_mel_len <= max_tokens, (
f"Audio {utt} has duration {total_mel_len * hop_length // target_sample_rate}s out of range [{min_secs}, {max_secs}]."
)
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
utts[bucket_i].append(utt)
@@ -389,10 +389,10 @@ def run_sim(args):
model = model.cuda(device)
model.eval()
sims = []
for wav1, wav2, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(wav1)
wav2, sr2 = torchaudio.load(wav2)
sim_results = []
for gen_wav, prompt_wav, truth in tqdm(test_set):
wav1, sr1 = torchaudio.load(gen_wav)
wav2, sr2 = torchaudio.load(prompt_wav)
resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
@@ -408,6 +408,11 @@ def run_sim(args):
sim = F.cosine_similarity(emb1, emb2)[0].item()
# print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
sims.append(sim)
sim_results.append(
{
"wav": Path(gen_wav).stem,
"sim": sim,
}
)
return sims
return sim_results

View File

@@ -4,16 +4,17 @@ The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://h
**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**
Currently support **30s for a single** generation, which is the **total length** including both prompt and output audio. However, you can provide `infer_cli` and `infer_gradio` with longer text, will automatically do chunk generation. Long reference audio will be **clip short to ~15s**.
Currently support **30s for a single** generation, which is the **total length** (same logic if `fix_duration`) including both prompt and output audio. However, `infer_cli` and `infer_gradio` will automatically do chunk generation for longer text. Long reference audio will be **clip short to ~12s**.
To avoid possible inference failures, make sure you have seen through the following instructions.
- Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses.
- Preprocess numbers to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), check for ffmpeg installation (various tutorials online, blogs, videos, etc.).
- Try turn off use_ema if using an early-stage finetuned checkpoint (which goes just few updates).
- Use reference audio <12s and leave proper silence space (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of word, leading to suboptimal generation.
- <ins>Uppercased letters</ins> (best with form like K.F.C.) will be uttered letter by letter, and lowercased letters used for common words.
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
## Gradio App
@@ -23,7 +24,7 @@ Currently supported features:
- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct
- [Custom inference with more language support](src/f5_tts/infer/SHARED.md)
- [Custom inference with more language support](SHARED.md)
The cli command `f5-tts_infer-gradio` equals to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio APP (web interface) for inference.
@@ -68,14 +69,16 @@ Basically you can inference with flags:
```bash
# Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--model F5TTS_v1_Base \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."
# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
# Use custom path checkpoint, e.g.
f5-tts_infer-cli --ckpt_file ckpts/F5TTS_v1_Base/model_1250000.safetensors
# More instructions
f5-tts_infer-cli --help
@@ -90,8 +93,8 @@ f5-tts_infer-cli -c custom.toml
For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
@@ -105,8 +108,8 @@ output_dir = "tests"
You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
@@ -126,6 +129,22 @@ ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
## Socket Real-time Service
Real-time voice output with chunk stream:
```bash
# Start socket server
python src/f5_tts/socket_server.py
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
# Communicate with socket client
python src/f5_tts/socket_client.py
```
## Speech Editing
To test speech editing capabilities, use the following command:
@@ -134,86 +153,3 @@ To test speech editing capabilities, use the following command:
python src/f5_tts/infer/speech_edit.py
```
## Socket Realtime Client
To communicate with socket server you need to run
```bash
python src/f5_tts/socket_server.py
```
<details>
<summary>Then create client to communicate</summary>
```bash
# If PyAudio not installed
sudo apt-get install portaudio19-dev
pip install pyaudio
```
``` python
# Create the socket_client.py
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
start_time = time.time()
first_chunk_time = None
async def play_audio_stream():
nonlocal first_chunk_time
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
try:
while True:
data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
if not data:
break
if data == b"END":
logger.info("End of audio received.")
break
audio_array = np.frombuffer(data, dtype=np.float32)
stream.write(audio_array.tobytes())
if first_chunk_time is None:
first_chunk_time = time.time()
finally:
stream.stop_stream()
stream.close()
p.terminate()
logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
try:
data_to_send = f"{text}".encode("utf-8")
await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
await play_audio_stream()
except Exception as e:
logger.error(f"Error in listen_to_F5TTS: {e}")
finally:
client_socket.close()
if __name__ == "__main__":
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
asyncio.run(listen_to_F5TTS(text_to_send))
```
</details>

View File

@@ -16,7 +16,7 @@
<!-- omit in toc -->
### Supported Languages
- [Multilingual](#multilingual)
- [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
- [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
- [English](#english)
- [Finnish](#finnish)
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,18 @@
## Multilingual
#### F5-TTS Base @ zh & en @ F5-TTS
#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
```bash
Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
# A Variant Model: hf://SWivid/F5-TTS/F5TTS_v1_Base_no_zero_init/model_1250000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
```
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
@@ -45,7 +56,7 @@
```bash
Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
*Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -64,7 +75,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
@@ -78,7 +89,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -96,7 +107,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -113,7 +124,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c
```bash
Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
Vocab: hf://alien79/F5-TTS-italian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -126,12 +137,12 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
#### F5-TTS Base @ ja @ Jmica
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_25498980)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_21999120)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
```bash
Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Model: hf://Jmica/F5TTS/JA_21999120/model_21999120.pt
Vocab: hf://Jmica/F5TTS/JA_21999120/vocab_japanese.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
@@ -148,7 +159,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
```bash
Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
- Any improvements are welcome

View File

@@ -1,5 +1,5 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."

View File

@@ -1,5 +1,5 @@
# F5-TTS | E2-TTS
model = "F5-TTS"
# F5TTS_v1_Base | E2TTS_Base
model = "F5TTS_v1_Base"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""

View File

@@ -10,24 +10,25 @@ import numpy as np
import soundfile as sf
import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
cross_fade_duration,
device,
fix_duration,
infer_process,
load_model,
load_vocoder,
mel_spec_type,
nfe_step,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
speed,
sway_sampling_coef,
target_rms,
)
from f5_tts.model import DiT, UNetT
parser = argparse.ArgumentParser(
@@ -50,7 +51,7 @@ parser.add_argument(
"-m",
"--model",
type=str,
help="The model name: F5-TTS | E2-TTS",
help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
)
parser.add_argument(
"-mc",
@@ -162,6 +163,11 @@ parser.add_argument(
type=float,
help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
)
parser.add_argument(
"--device",
type=str,
help="Specify the device to run on",
)
args = parser.parse_args()
@@ -172,8 +178,7 @@ config = tomli.load(open(args.config, "rb"))
# command-line interface parameters
model = args.model or config.get("model", "F5-TTS")
model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
model = args.model or config.get("model", "F5TTS_v1_Base")
ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
vocab_file = args.vocab_file or config.get("vocab_file", "")
@@ -203,6 +208,7 @@ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
speed = args.speed or config.get("speed", speed)
fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
device = args.device or config.get("device", device)
# patches for pip pkg user
@@ -240,41 +246,42 @@ if vocoder_name == "vocos":
elif vocoder_name == "bigvgan":
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
vocoder = load_vocoder(
vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path, device=device
)
# load TTS model
if model == "F5-TTS":
model_cls = DiT
model_cfg = OmegaConf.load(model_cfg).model.arch
if not ckpt_file: # path not specified, download from repo
if vocoder_name == "vocos":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base"
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
elif vocoder_name == "bigvgan":
repo_name = "F5-TTS"
exp_name = "F5TTS_Base_bigvgan"
ckpt_step = 1250000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
model_cfg = OmegaConf.load(
args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
)
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
elif model == "E2-TTS":
assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
if not ckpt_file: # path not specified, download from repo
repo_name = "E2-TTS"
exp_name = "E2TTS_Base"
repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
if model != "F5TTS_Base":
assert vocoder_name == model_cfg.model.mel_spec.mel_spec_type
# override for previous models
if model == "F5TTS_Base":
if vocoder_name == "vocos":
ckpt_step = 1200000
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
elif vocoder_name == "bigvgan":
model = "F5TTS_Base_bigvgan"
ckpt_type = "pt"
elif model == "E2TTS_Base":
repo_name = "E2-TTS"
ckpt_step = 1200000
if not ckpt_file:
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
ema_model = load_model(
model_cls, model_arc, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file, device=device
)
# inference process
@@ -330,6 +337,7 @@ def main():
sway_sampling_coef=sway_sampling_coef,
speed=speed,
fix_duration=fix_duration,
device=device,
)
generated_audio_segments.append(audio_segment)
@@ -337,7 +345,7 @@ def main():
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
final_sample_rate,
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,22 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
from f5_tts.model import CFM, DiT, UNetT
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
"cuda"
if torch.cuda.is_available()
@@ -21,44 +28,41 @@ device = (
)
# --------------------- Dataset Settings -------------------- #
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
target_rms = 0.1
tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
# ---------------------- infer setting ---------------------- #
seed = None # int | None
exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
ckpt_step = 1200000
exp_name = "F5TTS_v1_Base" # F5TTS_v1_Base | E2TTS_Base
ckpt_step = 1250000
nfe_step = 32 # 16, 32
cfg_strength = 2.0
ode_method = "euler" # euler | midpoint
sway_sampling_coef = -1.0
speed = 1.0
target_rms = 0.1
if exp_name == "F5TTS_Base":
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
dataset_name = model_cfg.datasets.name
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
hop_length = model_cfg.model.mel_spec.hop_length
win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
output_dir = "tests"
# [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
# pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
# [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -67,7 +71,7 @@ output_dir = "tests"
# [--language "zho" for Chinese, "eng" for English]
# [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [
@@ -106,7 +110,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
# Model
model = CFM(
transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
@@ -152,7 +156,7 @@ for part in parts_to_edit:
dim=-1,
)
offset = end * target_sample_rate
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)

View File

@@ -4,6 +4,7 @@ import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
@@ -14,6 +15,7 @@ from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
@@ -21,16 +23,14 @@ import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import snapshot_download, hf_hub_download
from huggingface_hub import hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos
from f5_tts.model import CFM
from f5_tts.model.utils import (
get_tokenizer,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}
@@ -128,11 +128,12 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
except ImportError:
print("You need to follow the README to init submodule and change the BigVGAN source code.")
if is_local:
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
# download generator from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
else:
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
vocoder = bigvgan.BigVGAN.from_pretrained(
"nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False, cache_dir=hf_cache_dir
)
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
@@ -149,7 +150,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -186,7 +187,7 @@ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
dtype = (
torch.float16
if "cuda" in device
and torch.cuda.get_device_properties(device).major >= 6
and torch.cuda.get_device_properties(device).major >= 7
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
else torch.float32
)
@@ -289,7 +290,7 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
aseg = AudioSegment.from_file(ref_audio_orig)
@@ -301,29 +302,29 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
show_info("Audio is over 15s, clipping short. (1)")
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 15000:
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
show_info("Audio is over 15s, clipping short. (2)")
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 15000:
aseg = aseg[:15000]
show_info("Audio is over 15s, clipping short. (3)")
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")
@@ -383,7 +384,7 @@ def infer_process(
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)
@@ -479,14 +480,15 @@ def infer_batch_process(
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
)
del _
generated = generated.to(torch.float32)
generated = generated.to(torch.float32) # generated mel spectrogram
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = generated.permute(0, 2, 1)
generated = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(generated_mel_spec)
generated_wave = vocoder.decode(generated)
elif mel_spec_type == "bigvgan":
generated_wave = vocoder(generated_mel_spec)
generated_wave = vocoder(generated)
if rms < target_rms:
generated_wave = generated_wave * rms / target_rms
@@ -497,7 +499,9 @@ def infer_batch_process(
for j in range(0, len(generated_wave), chunk_size):
yield generated_wave[j : j + chunk_size], target_sample_rate
else:
yield generated_wave, generated_mel_spec[0].cpu().numpy()
generated_cpu = generated[0].cpu().numpy()
del generated
yield generated_wave, generated_cpu
if streaming:
for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:

View File

@@ -1,9 +1,7 @@
from f5_tts.model.cfm import CFM
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer

View File

@@ -4,7 +4,7 @@
### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible abs pos emb & convnextv2 blocks for embedded text before concat
### dit.py
- adaln-zero dit
@@ -14,7 +14,7 @@
- possible long skip connection (first layer to last layer)
### mmdit.py
- sd3 structure
- stable diffusion 3 block structure
- timestep as condition
- left stream: text embedded and applied a abs pos emb
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett

View File

@@ -10,19 +10,18 @@ d - dimension
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -30,10 +29,12 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -49,6 +50,8 @@ class TextEmbedding(nn.Module):
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
@@ -64,7 +67,13 @@ class TextEmbedding(nn.Module):
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
@@ -103,7 +112,10 @@ class DiT(nn.Module):
mel_dim=100,
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -112,7 +124,10 @@ class DiT(nn.Module):
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -121,15 +136,40 @@ class DiT(nn.Module):
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
[
DiTBlock(
dim=dim,
heads=heads,
dim_head=dim_head,
ff_mult=ff_mult,
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
)
for _ in range(depth)
]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.checkpoint_activations = checkpoint_activations
self.initialize_weights()
def initialize_weights(self):
# Zero-out AdaLN layers in DiT blocks:
for block in self.transformer_blocks:
nn.init.constant_(block.attn_norm.linear.weight, 0)
nn.init.constant_(block.attn_norm.linear.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.norm_out.linear.weight, 0)
nn.init.constant_(self.norm_out.linear.bias, 0)
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def ckpt_wrapper(self, module):
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
def ckpt_forward(*inputs):
@@ -138,6 +178,9 @@ class DiT(nn.Module):
return ckpt_forward
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -147,14 +190,25 @@ class DiT(nn.Module):
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -164,7 +218,8 @@ class DiT(nn.Module):
for block in self.transformer_blocks:
if self.checkpoint_activations:
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
# https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
else:
x = block(x, t, mask=mask, rope=rope)

View File

@@ -11,16 +11,15 @@ from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -28,18 +27,24 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, out_dim, text_num_embeds):
def __init__(self, out_dim, text_num_embeds, mask_padding=True):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
self.precompute_max_pos = 1024
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
text = text + 1
if drop_text:
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text)
text = self.text_embed(text) # b nt -> b nt d
# sinus pos emb
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
@@ -49,6 +54,9 @@ class TextEmbedding(nn.Module):
text = text + text_pos_embed
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
return text
@@ -83,13 +91,16 @@ class MMDiT(nn.Module):
dim_head=64,
dropout=0.1,
ff_mult=4,
text_num_embeds=256,
mel_dim=100,
text_num_embeds=256,
text_mask_padding=True,
qk_norm=None,
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
self.text_embed = TextEmbedding(dim, text_num_embeds)
self.text_embed = TextEmbedding(dim, text_num_embeds, mask_padding=text_mask_padding)
self.text_cond, self.text_uncond = None, None # text cache
self.audio_embed = AudioEmbedding(mel_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -106,13 +117,33 @@ class MMDiT(nn.Module):
dropout=dropout,
ff_mult=ff_mult,
context_pre_only=i == depth - 1,
qk_norm=qk_norm,
)
for i in range(depth)
]
)
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.initialize_weights()
def initialize_weights(self):
# Zero-out AdaLN layers in MMDiT blocks:
for block in self.transformer_blocks:
nn.init.constant_(block.attn_norm_x.linear.weight, 0)
nn.init.constant_(block.attn_norm_x.linear.bias, 0)
nn.init.constant_(block.attn_norm_c.linear.weight, 0)
nn.init.constant_(block.attn_norm_c.linear.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.norm_out.linear.weight, 0)
nn.init.constant_(self.norm_out.linear.bias, 0)
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -122,6 +153,7 @@ class MMDiT(nn.Module):
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
):
batch = x.shape[0]
if time.ndim == 0:
@@ -129,7 +161,17 @@ class MMDiT(nn.Module):
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
c = self.text_embed(text, drop_text=drop_text)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, drop_text=True)
c = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, drop_text=False)
c = self.text_cond
else:
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
seq_len = x.shape[1]

View File

@@ -8,24 +8,24 @@ d - dimension
"""
from __future__ import annotations
from typing import Literal
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,
Attention,
AttnProcessor,
ConvNeXtV2Block,
ConvPositionEmbedding,
FeedForward,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -33,10 +33,12 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
@@ -52,6 +54,8 @@ class TextEmbedding(nn.Module):
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
@@ -67,7 +71,13 @@ class TextEmbedding(nn.Module):
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
return text
@@ -106,7 +116,10 @@ class UNetT(nn.Module):
mel_dim=100,
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
):
super().__init__()
@@ -115,7 +128,10 @@ class UNetT(nn.Module):
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
@@ -134,11 +150,12 @@ class UNetT(nn.Module):
attn_norm = RMSNorm(dim)
attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
qk_norm=qk_norm,
)
ff_norm = RMSNorm(dim)
@@ -161,6 +178,9 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
@@ -170,6 +190,7 @@ class UNetT(nn.Module):
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@@ -177,7 +198,17 @@ class UNetT(nn.Module):
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
# postfix time t to input x, [b n d] -> [b n+1 d]

View File

@@ -162,13 +162,13 @@ class CFM(nn.Module):
# predict flow
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
)
return pred + (pred - null_pred) * cfg_strength
@@ -195,6 +195,7 @@ class CFM(nn.Module):
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
self.transformer.clear_cache()
sampled = trajectory[-1]
out = sampled
@@ -269,7 +270,7 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text

View File

@@ -173,7 +173,7 @@ class DynamicBatchSampler(Sampler[list[int]]):
"""
def __init__(
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_residual: bool = False
):
self.sampler = sampler
self.frames_threshold = frames_threshold
@@ -208,12 +208,15 @@ class DynamicBatchSampler(Sampler[list[int]]):
batch = []
batch_frames = 0
if not drop_last and len(batch) > 0:
if not drop_residual and len(batch) > 0:
batches.append(batch)
del indices
self.batches = batches
# Ensure even batches with accelerate BatchSamplerShard cls under frame_per_batch setting
self.drop_last = True
def set_epoch(self, epoch: int) -> None:
"""Sets the epoch for this sampler."""
self.epoch = epoch

View File

@@ -269,11 +269,36 @@ class ConvNeXtV2Block(nn.Module):
return residual + x
# AdaLayerNormZero
# RMSNorm
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.native_rms_norm = float(torch.__version__[:3]) >= 2.4
def forward(self, x):
if self.native_rms_norm:
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.to(self.weight.dtype)
x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
else:
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.to(self.weight.dtype)
x = x * self.weight
return x
# AdaLayerNorm
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
class AdaLayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -290,11 +315,11 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# AdaLayerNorm for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
class AdaLayerNorm_Final(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -341,7 +366,8 @@ class Attention(nn.Module):
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
):
super().__init__()
@@ -362,18 +388,32 @@ class Attention(nn.Module):
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if qk_norm is None:
self.q_norm = None
self.k_norm = None
elif qk_norm == "rms_norm":
self.q_norm = RMSNorm(dim_head, eps=1e-6)
self.k_norm = RMSNorm(dim_head, eps=1e-6)
else:
raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
if self.context_dim is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
if qk_norm is None:
self.c_q_norm = None
self.c_k_norm = None
elif qk_norm == "rms_norm":
self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
if self.context_dim is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, context_dim)
def forward(
self,
@@ -393,8 +433,11 @@ class Attention(nn.Module):
class AttnProcessor:
def __init__(self):
pass
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
):
self.pe_attn_head = pe_attn_head
def __call__(
self,
@@ -405,19 +448,11 @@ class AttnProcessor:
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
# `sample` projections
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -425,6 +460,25 @@ class AttnProcessor:
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# qk norm
if attn.q_norm is not None:
query = attn.q_norm(query)
if attn.k_norm is not None:
key = attn.k_norm(key)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
if self.pe_attn_head is not None:
pn = self.pe_attn_head
query[:, :pn, :, :] = apply_rotary_pos_emb(query[:, :pn, :, :], freqs, q_xpos_scale)
key[:, :pn, :, :] = apply_rotary_pos_emb(key[:, :pn, :, :], freqs, k_xpos_scale)
else:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
@@ -470,16 +524,36 @@ class JointAttnProcessor:
batch_size = c.shape[0]
# `sample` projections.
# `sample` projections
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
# `context` projections
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# qk norm
if attn.q_norm is not None:
query = attn.q_norm(query)
if attn.k_norm is not None:
key = attn.k_norm(key)
if attn.c_q_norm is not None:
c_query = attn.c_q_norm(c_query)
if attn.c_k_norm is not None:
c_key = attn.c_k_norm(c_key)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
@@ -492,16 +566,10 @@ class JointAttnProcessor:
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# joint attention
query = torch.cat([query, c_query], dim=2)
key = torch.cat([key, c_key], dim=2)
value = torch.cat([value, c_value], dim=2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
@@ -540,16 +608,17 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
qk_norm=qk_norm,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -585,26 +654,30 @@ class MMDiTBlock(nn.Module):
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
def __init__(
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
):
super().__init__()
if context_dim is None:
context_dim = dim
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
self.attn_norm_x = AdaLayerNorm(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_dim=context_dim,
context_pre_only=context_pre_only,
qk_norm=qk_norm,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
self.ff_norm_c = nn.LayerNorm(context_dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None

View File

@@ -19,6 +19,7 @@ from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
@@ -32,7 +33,7 @@ class Trainer:
save_per_updates=1000,
keep_last_n_checkpoints: int = -1, # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
checkpoint_path=None,
batch_size=32,
batch_size_per_gpu=32,
batch_size_type: str = "sample",
max_samples=32,
grad_accumulation_steps=1,
@@ -40,7 +41,7 @@ class Trainer:
noise_scheduler: str | None = None,
duration_predictor: torch.nn.Module | None = None,
logger: str | None = "wandb", # "wandb" | "tensorboard" | None
wandb_project="test_e2-tts",
wandb_project="test_f5-tts",
wandb_run_name="test_run",
wandb_resume_id: str = None,
log_samples: bool = False,
@@ -51,6 +52,7 @@ class Trainer:
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
is_local_vocoder: bool = False, # use local path vocoder
local_vocoder_path: str = "", # local vocoder path
model_cfg_dict: dict = dict(), # training config
):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
@@ -72,21 +74,23 @@ class Trainer:
else:
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config={
if not model_cfg_dict:
model_cfg_dict = {
"epochs": epochs,
"learning_rate": learning_rate,
"num_warmup_updates": num_warmup_updates,
"batch_size": batch_size,
"batch_size_per_gpu": batch_size_per_gpu,
"batch_size_type": batch_size_type,
"max_samples": max_samples,
"grad_accumulation_steps": grad_accumulation_steps,
"max_grad_norm": max_grad_norm,
"gpus": self.accelerator.num_processes,
"noise_scheduler": noise_scheduler,
},
}
model_cfg_dict["gpus"] = self.accelerator.num_processes
self.accelerator.init_trackers(
project_name=wandb_project,
init_kwargs=init_kwargs,
config=model_cfg_dict,
)
elif self.logger == "tensorboard":
@@ -111,9 +115,9 @@ class Trainer:
self.save_per_updates = save_per_updates
self.keep_last_n_checkpoints = keep_last_n_checkpoints
self.last_per_updates = default(last_per_updates, save_per_updates)
self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
self.checkpoint_path = default(checkpoint_path, "ckpts/test_f5-tts")
self.batch_size = batch_size
self.batch_size_per_gpu = batch_size_per_gpu
self.batch_size_type = batch_size_type
self.max_samples = max_samples
self.grad_accumulation_steps = grad_accumulation_steps
@@ -179,7 +183,7 @@ class Trainer:
if (
not exists(self.checkpoint_path)
or not os.path.exists(self.checkpoint_path)
or not any(filename.endswith(".pt") for filename in os.listdir(self.checkpoint_path))
or not any(filename.endswith((".pt", ".safetensors")) for filename in os.listdir(self.checkpoint_path))
):
return 0
@@ -191,7 +195,7 @@ class Trainer:
all_checkpoints = [
f
for f in os.listdir(self.checkpoint_path)
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith((".pt", ".safetensors"))
]
# First try to find regular training checkpoints
@@ -205,8 +209,16 @@ class Trainer:
# If no training checkpoints, use pretrained model
latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
if latest_checkpoint.endswith(".safetensors"): # always a pretrained checkpoint
from safetensors.torch import load_file
checkpoint = load_file(f"{self.checkpoint_path}/{latest_checkpoint}", device="cpu")
checkpoint = {"ema_model_state_dict": checkpoint}
elif latest_checkpoint.endswith(".pt"):
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
checkpoint = torch.load(
f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu"
)
# patch for backward compatibility, 305e3ea
for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
@@ -271,7 +283,7 @@ class Trainer:
num_workers=num_workers,
pin_memory=True,
persistent_workers=True,
batch_size=self.batch_size,
batch_size=self.batch_size_per_gpu,
shuffle=True,
generator=generator,
)
@@ -280,10 +292,10 @@ class Trainer:
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(
sampler,
self.batch_size,
self.batch_size_per_gpu,
max_samples=self.max_samples,
random_seed=resumable_with_seed, # This enables reproducible shuffling
drop_last=False,
drop_residual=False,
)
train_dataloader = DataLoader(
train_dataset,
@@ -339,7 +351,7 @@ class Trainer:
progress_bar = tqdm(
range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
desc=f"Epoch {epoch+1}/{self.epochs}",
desc=f"Epoch {epoch + 1}/{self.epochs}",
unit="update",
disable=not self.accelerator.is_local_main_process,
initial=progress_bar_initial,
@@ -384,6 +396,9 @@ class Trainer:
self.writer.add_scalar("loss", loss.item(), global_update)
self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
if global_update % self.save_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update)
@@ -417,9 +432,7 @@ class Trainer:
torchaudio.save(
f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
)
if global_update % self.last_per_updates == 0 and self.accelerator.sync_gradients:
self.save_checkpoint(global_update, last=True)
self.model.train()
self.save_checkpoint(global_update, last=True)

View File

@@ -5,11 +5,10 @@ import random
from collections import defaultdict
from importlib.resources import files
import torch
from torch.nn.utils.rnn import pad_sequence
import jieba
from pypinyin import lazy_pinyin, Style
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything
@@ -133,11 +132,12 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
# convert char to pinyin
jieba.initialize()
print("Word segmentation module jieba initialized.\n")
def convert_char_to_pinyin(text_list, polyphone=True):
if jieba.dt.initialized is False:
jieba.default_logger.setLevel(50) # CRITICAL
jieba.initialize()
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}

View File

@@ -0,0 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
WORKDIR /workspace

View File

@@ -0,0 +1,69 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
```
### Build Image
Build the docker image from scratch.
```sh
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
```
### HTTP Client
```sh
python3 client_http.py
```
### Benchmark using Client-Server Mode
```sh
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark using Offline TRT-LLM Mode
```sh
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
```
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -0,0 +1,560 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 authors: Yuekai Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
"""
import argparse
import json
import os
import time
from typing import Dict, List, Union
import datasets
import jieba
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from f5_tts_trtllm import F5TTS
from huggingface_hub import hf_hub_download
from pypinyin import Style, lazy_pinyin
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
torch.manual_seed(0)
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
parser.add_argument(
"--vocab-file",
required=True,
type=str,
help="vocab file",
)
parser.add_argument(
"--model-path",
required=True,
type=str,
help="model path, to load text embedding",
)
parser.add_argument(
"--tllm-model-dir",
required=True,
type=str,
help="tllm model dir",
)
parser.add_argument(
"--batch-size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
parser.add_argument(
"--vocoder",
default="vocos",
type=str,
help="vocoder name",
)
parser.add_argument(
"--vocoder-trt-engine-path",
default=None,
type=str,
help="vocoder trt engine path",
)
parser.add_argument("--enable-warmup", action="store_true")
parser.add_argument("--remove-input-padding", action="store_true")
parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
args = parser.parse_args()
return args
def padded_mel_batch(ref_mels, max_seq_len):
padded_ref_mels = []
for mel in ref_mels:
# pad along the last dimension
padded_ref_mel = F.pad(mel, (0, 0, 0, max_seq_len - mel.shape[0]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
return padded_ref_mels
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
ids, ref_mel_list, ref_mel_len_list, estimated_reference_target_mel_len, reference_target_texts_list = (
[],
[],
[],
[],
[],
)
for i, item in enumerate(batch):
item_id, prompt_text, target_text = (
item["id"],
item["prompt_text"],
item["target_text"],
)
ids.append(item_id)
reference_target_texts_list.append(prompt_text + target_text)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_mel = mel_spectrogram(ref_audio, vocoder="vocos", device="cuda")
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel = ref_mel.squeeze()
ref_mel_len = ref_mel.shape[0]
assert ref_mel.shape[1] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append(
int(ref_mel.shape[0] * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
max_seq_len = max(estimated_reference_target_mel_len)
ref_mel_batch = padded_mel_batch(ref_mel_list, max_seq_len)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
"estimated_reference_target_mel_len": estimated_reference_target_mel_len,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
# Initialize process group with explicit device IDs
dist.init_process_group(
"nccl",
)
return world_size, local_rank, rank
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: Union[List[str], List[List[str]]],
vocab_char_map: Dict[str, int], # {char: idx}
padding_value=-1,
):
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
# text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return list_idx_tensors
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
if vocoder_name == "vocos":
if vocoder_trt_engine_path is not None:
vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
else:
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
from vocos.feature_extractors import EncodecFeatures
if isinstance(vocoder.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not implemented yet")
return vocoder
def mel_spectrogram(waveform, vocoder="vocos", device="cuda"):
if vocoder == "vocos":
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=24000,
n_fft=1024,
win_length=1024,
hop_length=256,
n_mels=100,
power=1,
center=True,
normalized=False,
norm=None,
).to(device)
mel = mel_stft(waveform.to(device))
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vae engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
self.session = Session.from_serialized_engine(engine_buffer)
self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
def decode(self, mels):
mels = mels.contiguous()
inputs = {"mel": mels}
output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
outputs = {
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
}
ok = self.session.run(inputs, outputs, self.stream)
assert ok, "Runtime execution failed for vae session"
samples = outputs["waveform"]
return samples
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file)
tllm_model_dir = args.tllm_model_dir
config_file = os.path.join(tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
if args.backend_type == "trt":
model = F5TTS(
config, debug_mode=False, tllm_model_dir=tllm_model_dir, model_path=args.model_path, vocab_size=vocab_size
)
elif args.backend_type == "pytorch":
import sys
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
pe_attn_head=1,
text_mask_padding=False,
)
model = load_model(DiT, F5TTS_model_cfg, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
)
dataset = load_dataset(
"yuekai/seed_tts",
split=args.split_name,
trust_remote_code=True,
)
def add_estimated_duration(example):
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
estimated_duration = prompt_audio_len * scale_factor
example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
return example
dataset = dataset.map(add_estimated_duration)
dataset = dataset.sort("estimated_duration", reverse=True)
if args.use_perf:
# dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
# dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
# dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
dataset = datasets.concatenate_datasets(dataset_list_short)
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
else:
# This would disable shuffling
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
)
total_steps = len(dataset)
if args.enable_warmup:
for batch in dataloader:
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq, ref_mels, ref_mel_lens, total_mel_lens, remove_input_padding=args.remove_input_padding
)
elif args.backend_type == "pytorch":
with torch.inference_mode():
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
total_mel_lens = torch.tensor(total_mel_lens, device=device)
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
decoding_time = 0
vocoder_time = 0
total_duration = 0
if args.use_perf:
torch.cuda.cudart().cudaProfilerStart()
total_decoding_time = time.time()
for batch in dataloader:
if args.use_perf:
torch.cuda.nvtx.range_push("data sample")
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
ref_mels,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
use_perf=args.use_perf,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
text_pad_seq -= 1
text_pad_seq[text_pad_seq == -2] = -1
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=16,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if args.vocoder == "vocos":
if args.use_perf:
torch.cuda.nvtx.range_push("vocoder decode")
generated_wave = vocoder.decode(gen_mel_spec).cpu()
if args.use_perf:
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
# if ref_rms_list[i] < target_rms:
# generated_wave = generated_wave * ref_rms_list[i] / target_rms
rms = torch.sqrt(torch.mean(torch.square(generated_wave)))
if rms < target_rms:
generated_wave = generated_wave * target_rms / rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",
generated_wave,
target_sample_rate,
)
total_duration += generated_wave.shape[1] / target_sample_rate
vocoder_time += time.time() - vocoder_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
total_decoding_time = time.time() - total_decoding_time
if rank == 0:
progress_bar.close()
rtf = total_decoding_time / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
s += f"batch size: {args.batch_size}\n"
print(s)
with open(f"{args.output_dir}/rtf.txt", "w") as f:
f.write(s)
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,469 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script supports to load dataset from huggingface and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import np_to_triton_dtype
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
# write a note, the log is from triton_client.get_inference_statistics(), to better human readability
summary_f.write(
"The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n"
)
summary_f.write("To learn more about the log, please refer to: \n")
summary_f.write("1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n")
summary_f.write("2. https://github.com/triton-inference-server/server/issues/5374 \n\n")
summary_f.write(
"To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n"
)
summary_f.write(
"However, there is a trade-off between the increased queue time and the increased batch size. \n"
)
summary_f.write(
"You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n"
)
summary_f.write(
"See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. \n\n"
)
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms / batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms / batch_count / batch_size:.2f} ms \n" # noqa
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, " # noqa
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n" # noqa
)
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Grpc port of the triton server, default is 8001",
)
parser.add_argument(
"--reference-audio",
type=str,
default=None,
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--huggingface-dataset",
type=str,
default="yuekai/seed_tts",
help="dataset name in huggingface dataset hub",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="dataset split name, default is 'test'",
)
parser.add_argument(
"--manifest-path",
type=str,
default=None,
help="Path to the manifest dir which includes wav.scp trans.txt files.",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
)
parser.add_argument(
"--num-tasks",
type=int,
default=1,
help="Number of concurrent tasks for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-wer",
action="store_true",
default=False,
help="""True to compute WER.
""",
)
parser.add_argument(
"--log-dir",
type=str,
required=False,
default="./tmp",
help="log directory",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Inference batch_size per request for offline mode.",
)
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate
async def send(
manifest_item_list: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
print(f"manifest_item_list: {manifest_item_list}")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
reference_text, target_text = item["reference_text"], item["target_text"]
estimated_target_duration = duration / len(reference_text) * len(target_text)
if padding_duration:
# padding to nearset 10 seconds
samples = np.zeros(
(
1,
padding_duration
* sample_rate
* ((int(estimated_target_duration + duration) // padding_duration) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput("reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)),
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
input_data_numpy = np.array([reference_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[2].set_data_from_numpy(input_data_numpy)
input_data_numpy = np.array([target_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
sequence_id = 100000000 + i + task_id * 10
start = time.time()
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
audio = response.as_numpy("waveform").reshape(-1)
end = time.time() - start
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, estimated_target_duration))
total_duration += estimated_target_duration
return total_duration, latency_data
def load_manifests(manifest_path):
with open(manifest_path, "r") as f:
manifest_list = []
for line in f:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
# gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
{
"audio_filepath": prompt_wav,
"reference_text": prompt_text,
"target_text": gt_text,
"target_audio_path": utt,
}
)
return manifest_list
def split_data(data, k):
n = len(data)
if n < k:
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
k = n
quotient = n // k
remainder = n % k
result = []
start = 0
for i in range(k):
if i < remainder:
end = start + quotient + 1
else:
end = start + quotient
result.append(data[start:end])
start = end
return result
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient
if args.reference_audio:
args.num_tasks = 1
args.log_interval = 1
manifest_item_list = [
{
"reference_text": args.reference_text,
"target_text": args.target_text,
"audio_filepath": args.reference_audio,
"target_audio_path": "test",
}
]
elif args.huggingface_dataset:
import datasets
dataset = datasets.load_dataset(
args.huggingface_dataset,
split=args.split_name,
trust_remote_code=True,
)
manifest_item_list = []
for i in range(len(dataset)):
manifest_item_list.append(
{
"audio_filepath": dataset[i]["prompt_audio"],
"reference_text": dataset[i]["prompt_text"],
"target_audio_path": dataset[i]["id"],
"target_text": dataset[i]["target_text"],
}
)
else:
manifest_item_list = load_manifests(args.manifest_path)
args.num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, args.num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
tasks = []
start_time = time.time()
for i in range(args.num_tasks):
task = asyncio.create_task(
send(
manifest_item_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
total_duration = 0.0
latency_data = []
for ans in ans_list:
total_duration += ans[0]
latency_data += ans[1]
rtf = elapsed / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
print(s)
if args.manifest_path:
name = Path(args.manifest_path).stem
elif args.split_name:
name = args.split_name
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True)
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,143 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import numpy as np
import requests
import soundfile as sf
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-url",
type=str,
default="localhost:8000",
help="Address of the server",
)
parser.add_argument(
"--reference-audio",
type=str,
default="../../infer/examples/basic/basic_ref_en.wav",
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="Some call me nature, others call me mother nature.",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
help="",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
reference_text,
target_text,
sample_rate=16000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
"datatype": "INT32",
"data": lengths.tolist(),
},
{"name": "reference_text", "shape": [1, 1], "datatype": "BYTES", "data": [reference_text]},
{"name": "target_text", "shape": [1, 1], "datatype": "BYTES", "data": [target_text]},
]
}
return data
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
)
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
sf.write(args.output_audio, audio, 24000, "PCM_16")

View File

@@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-f5-tts:24.12
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install vocos && rm -rf F5-TTS && git clone https://github.com/SWivid/F5-TTS.git && cd F5-TTS/src/f5_tts/runtime/triton_trtllm/ && bash run.sh 0 4 $MODEL"

View File

@@ -0,0 +1,430 @@
import math
import os
import time
from functools import wraps
from typing import List, Optional
import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
# Audio tensor case: batch, seq_len, feature_len
# position_ids case: batch, seq_len
assert input_tensor_lengths is not None, "input_tensor_lengths must be provided for 3D input_tensor"
# Initialize a list to collect valid sequences
valid_sequences = []
for i in range(input_tensor.shape[0]):
valid_length = input_tensor_lengths[i]
valid_sequences.append(input_tensor[i, :valid_length])
# Concatenate all valid sequences along the batch dimension
output_tensor = torch.cat(valid_sequences, dim=0).contiguous()
return output_tensor
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
if use_ema:
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
return text_embed_dict
class F5TTS(object):
def __init__(
self,
config,
debug_mode=True,
stream: Optional[torch.cuda.Stream] = None,
tllm_model_dir: Optional[str] = None,
model_path: Optional[str] = None,
vocab_size: Optional[int] = None,
):
self.dtype = config["pretrained_config"]["dtype"]
rank = tensorrt_llm.mpi_rank()
world_size = config["pretrained_config"]["mapping"]["world_size"]
cp_size = config["pretrained_config"]["mapping"]["cp_size"]
tp_size = config["pretrained_config"]["mapping"]["tp_size"]
pp_size = config["pretrained_config"]["mapping"]["pp_size"]
assert pp_size == 1
self.mapping = tensorrt_llm.Mapping(
world_size=world_size, rank=rank, cp_size=cp_size, tp_size=tp_size, pp_size=1, gpus_per_node=1
)
local_rank = rank % self.mapping.gpus_per_node
self.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(self.device)
self.stream = stream
if self.stream is None:
self.stream = torch.cuda.Stream(self.device)
torch.cuda.set_stream(self.stream)
engine_file = os.path.join(tllm_model_dir, f"rank{rank}.engine")
logger.info(f"Loading engine from {engine_file}")
with open(engine_file, "rb") as f:
engine_buffer = f.read()
assert engine_buffer is not None
self.session = Session.from_serialized_engine(engine_buffer)
self.debug_mode = debug_mode
self.inputs = {}
self.outputs = {}
self.buffer_allocated = False
expected_tensor_names = ["noise", "cond", "time", "rope_cos", "rope_sin", "input_lengths", "denoised"]
found_tensor_names = [self.session.engine.get_tensor_name(i) for i in range(self.session.engine.num_io_tensors)]
if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names):
logger.error(
f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
)
logger.error(
f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
)
logger.error(f"Expected tensor names: {expected_tensor_names}")
logger.error(f"Found tensor names: {found_tensor_names}")
raise RuntimeError("Tensor names in engine are not the same as expected.")
if self.debug_mode:
self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names))
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
freqs = torch.outer(torch.arange(self.max_mel_len, dtype=torch.float32), inv_freq) / self.interpolation_factor
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
emb = time_step[i] * emb_factor
time_expand[:, i, :] = torch.cat((emb.sin(), emb.cos()), dim=-1)
self.time_expand = time_expand.to(self.device)
self.delta_t = torch.cat((delta_t, delta_t), dim=0).contiguous().to(self.device)
def _tensor_dtype(self, name):
# return torch dtype given tensor name for convenience
dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
return dtype
def _setup(self, batch_size, seq_len):
for i in range(self.session.engine.num_io_tensors):
name = self.session.engine.get_tensor_name(i)
if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
shape = list(self.session.engine.get_tensor_shape(name))
shape[0] = batch_size
shape[1] = seq_len
self.outputs[name] = torch.empty(shape, dtype=self._tensor_dtype(name), device=self.device)
self.buffer_allocated = True
def cuda_stream_guard(func):
"""Sync external stream and set current stream to the one bound to the session. Reset on exit."""
@wraps(func)
def wrapper(self, *args, **kwargs):
external_stream = torch.cuda.current_stream()
if external_stream != self.stream:
external_stream.synchronize()
torch.cuda.set_stream(self.stream)
ret = func(self, *args, **kwargs)
if external_stream != self.stream:
self.stream.synchronize()
torch.cuda.set_stream(external_stream)
return ret
return wrapper
@cuda_stream_guard
def forward(
self,
noise: torch.Tensor,
cond: torch.Tensor,
time_expand: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
input_lengths: torch.Tensor,
delta_t: torch.Tensor,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("flow matching")
cfg_strength = 2.0
batch_size = noise.shape[0]
half_batch = batch_size // 2
noise_half = noise[:half_batch] # Store the initial half of noise
input_type = str_dtype_to_torch(self.dtype)
# Keep a copy of the initial tensors
cond = cond.to(input_type)
rope_cos = rope_cos.to(input_type)
rope_sin = rope_sin.to(input_type)
input_lengths = input_lengths.to(str_dtype_to_torch("int32"))
# Instead of iteratively updating noise within a single model context,
# we'll do a single forward pass for each iteration with fresh context setup
for i in range(self.nfe_steps):
# Re-setup the buffers for clean execution
self._setup(batch_size, noise.shape[1])
if not self.buffer_allocated:
raise RuntimeError("Buffer not allocated, please call setup first!")
# Re-create combined noises for this iteration
current_noise = torch.cat([noise_half, noise_half], dim=0).to(input_type)
# Get time step for this iteration
current_time = time_expand[:, i].to(input_type)
# Create fresh input dictionary for this iteration
current_inputs = {
"noise": current_noise,
"cond": cond,
"time": current_time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}
# Update inputs and set shapes
self.inputs.clear() # Clear previous inputs
self.inputs.update(**current_inputs)
self.session.set_shapes(self.inputs)
if use_perf:
torch.cuda.nvtx.range_push(f"execute {i}")
ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
assert ok, "Failed to execute model"
# self.session.context.execute_async_v3(self.stream.cuda_stream)
if use_perf:
torch.cuda.nvtx.range_pop()
# Process results
t_scale = delta_t[i].unsqueeze(0).to(input_type)
# Extract predictions
pred_cond = self.outputs["denoised"][:half_batch]
pred_uncond = self.outputs["denoised"][half_batch:]
# Apply classifier-free guidance with safeguards
guidance = pred_cond + (pred_cond - pred_uncond) * cfg_strength
# Calculate update for noise
noise_half = noise_half + guidance * t_scale
if use_perf:
torch.cuda.nvtx.range_pop()
return noise_half
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
use_perf: bool = False,
):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
)
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
text_embedding_drop,
),
dim=-1,
)
time_expand = self.time_expand.repeat(2 * batch, 1, 1).contiguous()
# Convert estimated_reference_target_mel_len to tensor
input_lengths = torch.tensor(estimated_reference_target_mel_len, dtype=torch.int32)
# combine above along the batch dimension
inputs = {
"noise": torch.cat((noise, noise), dim=0).contiguous(),
"cond": torch.cat((cat_mel_text, cat_mel_text_drop), dim=0).contiguous(),
"time_expand": time_expand,
"rope_cos": torch.cat((rope_cos, rope_cos), dim=0).contiguous(),
"rope_sin": torch.cat((rope_sin, rope_sin), dim=0).contiguous(),
"input_lengths": torch.cat((input_lengths, input_lengths), dim=0).contiguous(),
"delta_t": self.delta_t,
}
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding")
if remove_input_padding:
max_seq_len = inputs["cond"].shape[1]
inputs["noise"] = remove_tensor_padding(inputs["noise"], inputs["input_lengths"])
inputs["cond"] = remove_tensor_padding(inputs["cond"], inputs["input_lengths"])
# for time_expand, convert from B,D to B,T,D by repeat
inputs["time_expand"] = inputs["time_expand"].unsqueeze(1).repeat(1, max_seq_len, 1, 1)
inputs["time_expand"] = remove_tensor_padding(inputs["time_expand"], inputs["input_lengths"])
inputs["rope_cos"] = remove_tensor_padding(inputs["rope_cos"], inputs["input_lengths"])
inputs["rope_sin"] = remove_tensor_padding(inputs["rope_sin"], inputs["input_lengths"])
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
for key in inputs:
inputs[key] = inputs[key].to(self.device)
if use_perf:
torch.cuda.nvtx.range_pop()
start_time = time.time()
denoised = self.forward(**inputs, use_perf=use_perf)
cost_time = time.time() - start_time
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_push("remove input padding output")
if remove_input_padding:
denoised_list = []
start_idx = 0
for i in range(batch):
denoised_list.append(denoised[start_idx : start_idx + inputs["input_lengths"][i]])
start_idx += inputs["input_lengths"][i]
if use_perf and remove_input_padding:
torch.cuda.nvtx.range_pop()
return denoised_list, cost_time
return denoised, cost_time

View File

@@ -0,0 +1,278 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import jieba
import torch
import torch.nn.functional as F
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
from torch.utils.dlpack import from_dlpack, to_dlpack
def get_tokenizer(vocab_file_path: str):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
with open(vocab_file_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
final_reference_target_texts_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(seg): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
else:
char_list.append(c)
final_reference_target_texts_list.append(char_list)
return final_reference_target_texts_list
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
class TritonPythonModel:
def initialize(self, args):
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
parameters[key] = value["string_value"]
self.vocab_char_map, self.vocab_size = get_tokenizer(parameters["vocab_file"])
self.reference_sample_rate = int(parameters["reference_audio_sample_rate"])
self.resampler = torchaudio.transforms.Resample(self.reference_sample_rate, self.target_audio_sample_rate)
self.tllm_model_dir = parameters["tllm_model_dir"]
config_file = os.path.join(self.tllm_model_dir, "config.json")
with open(config_file) as f:
config = json.load(f)
self.model = F5TTS(
config,
debug_mode=False,
tllm_model_dir=self.tllm_model_dir,
model_path=parameters["model_path"],
vocab_size=self.vocab_size,
)
self.vocoder = parameters["vocoder"]
assert self.vocoder in ["vocos", "bigvgan"]
if self.vocoder == "vocos":
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_audio_sample_rate,
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
n_mels=self.n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(self.device)
self.compute_mel_fn = self.get_vocos_mel_spectrogram
elif self.vocoder == "bigvgan":
self.compute_mel_fn = self.get_bigvgan_mel_spectrogram
def get_vocos_mel_spectrogram(self, waveform):
mel = self.mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel.transpose(1, 2)
def forward_vocoder(self, mel):
mel = mel.to(torch.float32).contiguous().cpu()
input_tensor_0 = pb_utils.Tensor.from_dlpack("mel", to_dlpack(mel))
inference_request = pb_utils.InferenceRequest(
model_name="vocoder", requested_output_names=["waveform"], inputs=[input_tensor_0]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
else:
waveform = pb_utils.get_output_tensor_by_name(inference_response, "waveform")
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def execute(self, requests):
(
reference_text_list,
target_text_list,
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
for request in requests:
wav_tensor = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_lens = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode("utf-8")
reference_text_list.append(reference_text)
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode("utf-8")
target_text_list.append(target_text)
text = reference_text + target_text
reference_target_texts_list.append(text)
wav = from_dlpack(wav_tensor.to_dlpack())
wav_len = from_dlpack(wav_lens.to_dlpack())
wav_len = wav_len.squeeze()
assert wav.shape[0] == 1, "Only support batch size 1 for now."
wav = wav[:, :wav_len]
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
if self.use_perf:
torch.cuda.nvtx.range_push("compute_mel")
mel_features = self.compute_mel_fn(wav)
if self.use_perf:
torch.cuda.nvtx.range_pop()
mel_features_list.append(mel_features)
reference_mel_len.append(mel_features.shape[1])
estimated_reference_target_mel_len.append(
int(
mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
)
)
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
reference_mel_len_tensor = torch.LongTensor(reference_mel_len).to(self.device)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
denoised, cost_time = self.model.sample(
text_pad_sequence,
mel_features,
reference_mel_len_tensor,
estimated_reference_target_mel_len,
remove_input_padding=False,
use_perf=self.use_perf,
)
if self.use_perf:
torch.cuda.nvtx.range_push("vocoder")
responses = []
for i in range(batch):
ref_me_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
responses.append(inference_response)
if self.use_perf:
torch.cuda.nvtx.range_pop()
return responses

View File

@@ -0,0 +1,81 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "f5_tts"
backend: "python"
max_batch_size: 4
dynamic_batching {
max_queue_delay_microseconds: 1000
}
parameters [
{
key: "vocab_file"
value: { string_value: "${vocab}"}
},
{
key: "model_path",
value: {string_value:"${model}"}
},
{
key: "tllm_model_dir",
value: {string_value:"${trtllm}"}
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
},
{
key: "vocoder",
value: {string_value:"${vocoder}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: True
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: True
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_GPU
}
]

View File

@@ -0,0 +1,32 @@
name: "vocoder"
backend: "tensorrt"
default_model_filename: "vocoder.plan"
max_batch_size: 4
input [
{
name: "mel"
data_type: TYPE_FP32
dims: [ 100, -1 ]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
dynamic_batching {
preferred_batch_size: [1, 2, 4]
max_queue_delay_microseconds: 1
}
instance_group [
{
count: 1
kind: KIND_GPU
}
]

View File

@@ -0,0 +1,199 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .baichuan.model import BaichuanForCausalLM
from .bert.model import (
BertForQuestionAnswering,
BertForSequenceClassification,
BertModel,
RobertaForQuestionAnswering,
RobertaForSequenceClassification,
RobertaModel,
)
from .bloom.model import BloomForCausalLM, BloomModel
from .chatglm.config import ChatGLMConfig
from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
from .cogvlm.config import CogVLMConfig
from .cogvlm.model import CogVLMForCausalLM
from .commandr.model import CohereForCausalLM
from .dbrx.config import DbrxConfig
from .dbrx.model import DbrxForCausalLM
from .deepseek_v1.model import DeepseekForCausalLM
from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .f5tts.model import F5TTS
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
from .gemma.model import GemmaForCausalLM
from .gpt.config import GPTConfig
from .gpt.model import GPTForCausalLM, GPTModel
from .gptj.config import GPTJConfig
from .gptj.model import GPTJForCausalLM, GPTJModel
from .gptneox.model import GPTNeoXForCausalLM, GPTNeoXModel
from .grok.model import GrokForCausalLM
from .llama.config import LLaMAConfig
from .llama.model import LLaMAForCausalLM, LLaMAModel
from .mamba.model import MambaForCausalLM
from .medusa.config import MedusaConfig
from .medusa.model import MedusaForCausalLm
from .mllama.model import MLLaMAModel
from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi.model import PhiForCausalLM, PhiModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
__all__ = [
"BertModel",
"BertForQuestionAnswering",
"BertForSequenceClassification",
"RobertaModel",
"RobertaForQuestionAnswering",
"RobertaForSequenceClassification",
"BloomModel",
"BloomForCausalLM",
"DiT",
"DeepseekForCausalLM",
"FalconConfig",
"DeepseekV2ForCausalLM",
"FalconForCausalLM",
"FalconModel",
"GPTConfig",
"GPTModel",
"GPTForCausalLM",
"OPTForCausalLM",
"OPTModel",
"LLaMAConfig",
"LLaMAForCausalLM",
"LLaMAModel",
"MedusaConfig",
"MedusaForCausalLm",
"ReDrafterForCausalLM",
"GPTJConfig",
"GPTJModel",
"GPTJForCausalLM",
"GPTNeoXModel",
"GPTNeoXForCausalLM",
"PhiModel",
"PhiConfig",
"Phi3Model",
"Phi3Config",
"PhiForCausalLM",
"Phi3ForCausalLM",
"ChatGLMConfig",
"ChatGLMForCausalLM",
"ChatGLMModel",
"BaichuanForCausalLM",
"QWenConfigQWenForCausalLM",
"QWenModel",
"EncoderModel",
"DecoderModel",
"PretrainedConfig",
"PretrainedModel",
"WhisperEncoder",
"MambaForCausalLM",
"MambaConfig",
"MPTForCausalLM",
"MPTModel",
"SkyworkForCausalLM",
"GemmaConfig",
"GemmaForCausalLM",
"DbrxConfig",
"DbrxForCausalLM",
"RecurrentGemmaForCausalLM",
"CogVLMConfig",
"CogVLMForCausalLM",
"EagleForCausalLM",
"SpeculativeDecodingMode",
"CohereForCausalLM",
"MLLaMAModel",
"F5TTS",
]
MODEL_MAP = {
"GPT2LMHeadModel": GPTForCausalLM,
"GPT2LMHeadCustomModel": GPTForCausalLM,
"GPTBigCodeForCausalLM": GPTForCausalLM,
"Starcoder2ForCausalLM": GPTForCausalLM,
"FuyuForCausalLM": GPTForCausalLM,
"Kosmos2ForConditionalGeneration": GPTForCausalLM,
"JAISLMHeadModel": GPTForCausalLM,
"GPTForCausalLM": GPTForCausalLM,
"NemotronForCausalLM": GPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"BloomForCausalLM": BloomForCausalLM,
"RWForCausalLM": FalconForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"PhiForCausalLM": PhiForCausalLM,
"Phi3ForCausalLM": Phi3ForCausalLM,
"Phi3VForCausalLM": Phi3ForCausalLM,
"Phi3SmallForCausalLM": Phi3ForCausalLM,
"PhiMoEForCausalLM": Phi3ForCausalLM,
"MambaForCausalLM": MambaForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"GLMModel": ChatGLMForCausalLM,
"ChatGLMModel": ChatGLMForCausalLM,
"ChatGLMForCausalLM": ChatGLMForCausalLM,
"LlamaForCausalLM": LLaMAForCausalLM,
"ExaoneForCausalLM": LLaMAForCausalLM,
"MistralForCausalLM": LLaMAForCausalLM,
"MixtralForCausalLM": LLaMAForCausalLM,
"ArcticForCausalLM": LLaMAForCausalLM,
"Grok1ModelForCausalLM": GrokForCausalLM,
"InternLMForCausalLM": LLaMAForCausalLM,
"InternLM2ForCausalLM": LLaMAForCausalLM,
"MedusaForCausalLM": MedusaForCausalLm,
"ReDrafterForCausalLM": ReDrafterForCausalLM,
"BaichuanForCausalLM": BaichuanForCausalLM,
"BaiChuanForCausalLM": BaichuanForCausalLM,
"SkyworkForCausalLM": LLaMAForCausalLM,
GEMMA_ARCHITECTURE: GemmaForCausalLM,
GEMMA2_ARCHITECTURE: GemmaForCausalLM,
"QWenLMHeadModel": QWenForCausalLM,
"QWenForCausalLM": QWenForCausalLM,
"Qwen2ForCausalLM": QWenForCausalLM,
"Qwen2MoeForCausalLM": QWenForCausalLM,
"Qwen2ForSequenceClassification": QWenForCausalLM,
"Qwen2VLForConditionalGeneration": QWenForCausalLM,
"WhisperEncoder": WhisperEncoder,
"EncoderModel": EncoderModel,
"DecoderModel": DecoderModel,
"DbrxForCausalLM": DbrxForCausalLM,
"RecurrentGemmaForCausalLM": RecurrentGemmaForCausalLM,
"CogVLMForCausalLM": CogVLMForCausalLM,
"DiT": DiT,
"DeepseekForCausalLM": DeepseekForCausalLM,
"DeciLMForCausalLM": DeciLMForCausalLM,
"DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
"EagleForCausalLM": EagleForCausalLM,
"CohereForCausalLM": CohereForCausalLM,
"MllamaForConditionalGeneration": MLLaMAModel,
"BertForQuestionAnswering": BertForQuestionAnswering,
"BertForSequenceClassification": BertForSequenceClassification,
"BertModel": BertModel,
"RobertaModel": RobertaModel,
"RobertaForQuestionAnswering": RobertaForQuestionAnswering,
"RobertaForSequenceClassification": RobertaForSequenceClassification,
"F5TTS": F5TTS,
}

View File

@@ -0,0 +1,222 @@
from __future__ import annotations
import os
import sys
from collections import OrderedDict
import tensorrt as trt
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
sys.path.append(parent_dir)
class InputEmbedding(Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(self, x, cond):
x = self.proj(concat([x, cond], dim=-1))
return self.conv_pos_embed(x) + x
class F5TTS(PretrainedModel):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.dtype = str_dtype_to_trt(config.dtype)
self.time_embed = TimestepEmbedding(config.hidden_size)
self.input_embed = InputEmbedding(config.mel_dim, config.text_dim, config.hidden_size)
self.dim = config.hidden_size
self.depth = config.num_hidden_layers
self.transformer_blocks = ModuleList(
[
DiTBlock(
dim=self.dim,
heads=config.num_attention_heads,
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
)
for _ in range(self.depth)
]
)
self.norm_out = AdaLayerNormZero_Final(config.hidden_size) # final modulation
self.proj_out = Linear(config.hidden_size, config.mel_dim)
def forward(
self,
noise, # nosied input audio
cond, # masked cond audio
time, # time step
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
):
t = self.time_embed(time)
x = self.input_embed(noise, cond)
for block in self.transformer_blocks:
x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
denoise = self.proj_out(self.norm_out(x, t))
denoise.mark_output("denoised", self.dtype)
return denoise
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
if default_net().plugin_config.remove_input_padding:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, mel_size],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, concat_feature_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, head_dim],
dim_range=OrderedDict(
[
("num_frames", [num_frames_range]),
("head_dim", [head_dim]),
]
),
)
else:
noise = Tensor(
name="noise",
dtype=self.dtype,
shape=[-1, -1, mel_size],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("n_mels", [mel_size]),
]
),
)
cond = Tensor(
name="cond",
dtype=self.dtype,
shape=[-1, -1, concat_feature_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("embeded_length", [concat_feature_dim]),
]
),
)
time = Tensor(
name="time",
dtype=self.dtype,
shape=[-1, freq_embed_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("freq_dim", [freq_embed_dim]),
]
),
)
rope_cos = Tensor(
name="rope_cos",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
rope_sin = Tensor(
name="rope_sin",
dtype=self.dtype,
shape=[-1, -1, head_dim],
dim_range=OrderedDict(
[
("batch_size", [batch_size_range]),
("max_duratuion", [[100, max_seq_len // 2, max_seq_len]]),
("head_dim", [head_dim]),
]
),
)
input_lengths = Tensor(
name="input_lengths",
dtype=trt.int32,
shape=[-1],
dim_range=OrderedDict([("batch_size", [batch_size_range])]),
)
return {
"noise": noise,
"cond": cond,
"time": time,
"rope_cos": rope_cos,
"rope_sin": rope_sin,
"input_lengths": input_lengths,
}

View File

@@ -0,0 +1,412 @@
from __future__ import annotations
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
Tensor,
bert_attention,
cast,
chunk,
concat,
constant,
expand,
expand_dims,
expand_dims_like,
expand_mask,
gelu,
matmul,
permute,
shape,
silu,
slice,
softmax,
squeeze,
unsqueeze,
view,
)
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module
class FeedForward(Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.project_in = Linear(dim, inner_dim)
self.ff = Linear(inner_dim, dim_out)
def forward(self, x):
return self.ff(gelu(self.project_in(x)))
class AdaLayerNormZero(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 6)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(emb, 6, dim=1)
x = self.norm(x)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = x * (ones + scale_msa) + shift_msa
else:
x = x * (ones + unsqueeze(scale_msa, 1)) + unsqueeze(shift_msa, 1)
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZero_Final(Module):
def __init__(self, dim):
super().__init__()
self.linear = Linear(dim, dim * 2)
self.norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(silu(emb))
scale, shift = chunk(emb, 2, dim=1)
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
x = self.norm(x) * (ones + scale) + shift
else:
x = self.norm(x) * unsqueeze((ones + scale), 1)
x = x + unsqueeze(shift, 1)
return x
class ConvPositionEmbedding(Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d1 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
self.mish = Mish()
def forward(self, x, mask=None): # noqa: F722
if default_net().plugin_config.remove_input_padding:
x = unsqueeze(x, 0)
x = permute(x, [0, 2, 1])
x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
out = permute(x, [0, 2, 1])
if default_net().plugin_config.remove_input_padding:
out = squeeze(out, 0)
return out
class Attention(Module):
def __init__(
self,
processor: AttnProcessor,
dim: int,
heads: int = 16,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim # hidden_size
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.attention_head_size = dim_head
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.tp_size = 1
self.num_attention_heads = heads // self.tp_size
self.num_attention_kv_heads = heads // self.tp_size # 8
self.dtype = str_dtype_to_trt("float32")
self.attention_hidden_size = self.attention_head_size * self.num_attention_heads
self.to_q = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_k = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
self.to_v = ColumnLinear(
dim,
self.tp_size * self.num_attention_heads * self.attention_head_size,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_dim is not None:
self.to_k_c = Linear(context_dim, self.inner_dim)
self.to_v_c = Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = Linear(context_dim, self.inner_dim)
self.to_out = RowLinear(
self.tp_size * self.num_attention_heads * self.attention_head_size,
dim,
bias=True,
dtype=self.dtype,
tp_group=None,
tp_size=self.tp_size,
)
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = Linear(self.inner_dim, dim)
def forward(
self,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
c=None, # context c
scale=1.0,
rope=None,
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, input_lengths=input_lengths, scale=scale, rope=rope, c_rope=c_rope)
else:
return self.processor(
self, x, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale
)
def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
shape_tensor = concat(
[shape(tensor, i) / 2 if i == (tensor.ndim() - 1) else shape(tensor, i) for i in range(tensor.ndim())]
)
if default_net().plugin_config.remove_input_padding:
assert tensor.ndim() == 2
x1 = slice(tensor, [0, 0], shape_tensor, [1, 2])
x2 = slice(tensor, [0, 1], shape_tensor, [1, 2])
x1 = expand_dims(x1, 2)
x2 = expand_dims(x2, 2)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 2)
out = view(x, concat([shape(x, 0), shape(x, 1) * 2]))
else:
assert tensor.ndim() == 3
x1 = slice(tensor, [0, 0, 0], shape_tensor, [1, 1, 2])
x2 = slice(tensor, [0, 0, 1], shape_tensor, [1, 1, 2])
x1 = expand_dims(x1, 3)
x2 = expand_dims(x2, 3)
zero = constant(np.ascontiguousarray(np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
x2 = zero - x2
x = concat([x2, x1], 3)
out = view(x, concat([shape(x, 0), shape(x, 1), shape(x, 2) * 2]))
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
return out
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn,
x, # noised input x
rope_cos,
rope_sin,
input_lengths,
scale=1.0,
rope=None,
) -> torch.FloatTensor:
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
# attention
inner_dim = key.shape[-1]
norm_factor = math.sqrt(attn.attention_head_size)
q_scaling = 1.0 / norm_factor
mask = None
if not default_net().plugin_config.remove_input_padding:
N = shape(x, 1)
B = shape(x, 0)
seq_len_2d = concat([1, N])
max_position_embeddings = 4096
# create position ids
position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
tmp_position_ids = expand(tmp_position_ids, concat([B, N])) # BxL
tmp_input_lengths = unsqueeze(input_lengths, 1) # Bx1
tmp_input_lengths = expand(tmp_input_lengths, concat([B, N])) # BxL
mask = tmp_position_ids < tmp_input_lengths # BxL
mask = mask.cast("int32")
if default_net().plugin_config.bert_attention_plugin:
qkv = concat([query, key, value], dim=-1)
# TRT plugin mode
assert input_lengths is not None
if default_net().plugin_config.remove_input_padding:
qkv = qkv.view(concat([-1, 3 * inner_dim]))
max_input_length = constant(
np.zeros(
[
2048,
],
dtype=np.int32,
)
)
else:
max_input_length = None
context = bert_attention(
qkv,
input_lengths,
attn.num_attention_heads,
attn.attention_head_size,
q_scaling=q_scaling,
max_input_length=max_input_length,
)
else:
assert not default_net().plugin_config.remove_input_padding
def transpose_for_scores(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.transpose(1, 2)
return y
def transpose_for_scores_k(x):
new_x_shape = concat([shape(x, 0), shape(x, 1), attn.num_attention_heads, attn.attention_head_size])
y = x.view(new_x_shape)
y = y.permute([0, 2, 3, 1])
return y
query = transpose_for_scores(query)
key = transpose_for_scores_k(key)
value = transpose_for_scores(value)
attention_scores = matmul(query, key, use_fp32_acc=False)
if mask is not None:
attention_mask = expand_mask(mask, shape(query, 2))
attention_mask = cast(attention_mask, attention_scores.dtype)
attention_scores = attention_scores + attention_mask
attention_probs = softmax(attention_scores, dim=-1)
context = matmul(attention_probs, value, use_fp32_acc=False).transpose(1, 2)
context = context.view(concat([shape(context, 0), shape(context, 1), attn.attention_hidden_size]))
context = attn.to_out(context)
if mask is not None:
mask = mask.view(concat([shape(mask, 0), shape(mask, 1), 1]))
mask = expand_dims_like(mask, context)
mask = cast(mask, context.dtype)
context = context * mask
return context
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)
def forward(
self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
# norm ----> (2,1226,1024)
attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
# process attention output for input x
if default_net().plugin_config.remove_input_padding:
x = x + gate_msa * attn_output
else:
x = x + unsqueeze(gate_msa, 1) * attn_output
ones = constant(np.ones(1, dtype=np.float32)).cast(x.dtype)
if default_net().plugin_config.remove_input_padding:
norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
else:
norm = self.ff_norm(x) * (ones + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
# norm = self.ff_norm(x) * (ones + scale_mlp) + shift_mlp
ff_output = self.ff(norm)
if default_net().plugin_config.remove_input_padding:
x = x + gate_mlp * ff_output
else:
x = x + unsqueeze(gate_mlp, 1) * ff_output
return x
class TimestepEmbedding(Module):
def __init__(self, dim, freq_embed_dim=256, dtype=None):
super().__init__()
# self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.mlp1 = Linear(freq_embed_dim, dim, bias=True, dtype=dtype)
self.mlp2 = Linear(dim, dim, bias=True, dtype=dtype)
def forward(self, timestep):
t_freq = self.mlp1(timestep)
t_freq = silu(t_freq)
t_emb = self.mlp2(t_freq)
return t_emb

View File

@@ -0,0 +1,24 @@
accelerate>=0.33.0
bitsandbytes>0.37.0
cached_path
click
datasets
ema_pytorch>=0.5.2
gradio>=3.45.2
hydra-core>=1.3.0
jieba
librosa
matplotlib
numpy<=1.26.4
pydub
pypinyin
safetensors
soundfile
tomli
torch>=2.0.0
# torchaudio>=2.0.0
torchdiffeq
tqdm>=4.65.0
transformers
x_transformers>=1.31.14
packaging>=24.2

View File

@@ -0,0 +1,110 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
if [ -z "$model" ]; then
echo "Model is none, using default model F5TTS_Base"
model=F5TTS_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Testing http client"
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "TRT-LLM: offline decoding benchmark test"
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
pip install -r requirements-pytorch.txt
batch_size=1
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./log_benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
ln -s model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py ./
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt \
--vocab-file $F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt \
--backend-type $backend_type \
--tllm-model-dir $F5_TTS_TRT_LLM_ENGINE_PATH || exit 1
fi

View File

@@ -0,0 +1,248 @@
# Modified from https://github.com/echocatzh/conv-stft/blob/master/conv_stft/conv_stft.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
# Copyright (c) 2020 Shimin Zhang
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft
support_clp_op = True
else:
from torch import rfft as fft
class STFT(th.nn.Module):
def __init__(
self,
win_len=1024,
win_hop=512,
fft_len=1024,
enframe_mode="continue",
win_type="hann",
win_sqrt=False,
pad_center=True,
):
"""
Implement of STFT using 1D convolution and 1D transpose convolutions.
Implement of framing the signal in 2 ways, `break` and `continue`.
`break` method is a kaldi-like framing.
`continue` method is a librosa-like framing.
More information about `perfect reconstruction`:
1. https://ww2.mathworks.cn/help/signal/ref/stft.html
2. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
Args:
win_len (int): Number of points in one frame. Defaults to 1024.
win_hop (int): Number of framing stride. Defaults to 512.
fft_len (int): Number of DFT points. Defaults to 1024.
enframe_mode (str, optional): `break` and `continue`. Defaults to 'continue'.
win_type (str, optional): The type of window to create. Defaults to 'hann'.
win_sqrt (bool, optional): using square root window. Defaults to True.
pad_center (bool, optional): `perfect reconstruction` opts. Defaults to True.
"""
super(STFT, self).__init__()
assert enframe_mode in ["break", "continue"]
assert fft_len >= win_len
self.win_len = win_len
self.win_hop = win_hop
self.fft_len = fft_len
self.mode = enframe_mode
self.win_type = win_type
self.win_sqrt = win_sqrt
self.pad_center = pad_center
self.pad_amount = self.fft_len // 2
en_k, fft_k, ifft_k, ola_k = self.__init_kernel__()
self.register_buffer("en_k", en_k)
self.register_buffer("fft_k", fft_k)
self.register_buffer("ifft_k", ifft_k)
self.register_buffer("ola_k", ola_k)
def __init_kernel__(self):
"""
Generate enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
** enframe_kernel: Using conv1d layer and identity matrix.
** fft_kernel: Using linear layer for matrix multiplication. In fact,
enframe_kernel and fft_kernel can be combined, But for the sake of
readability, I took the two apart.
** ifft_kernel, pinv of fft_kernel.
** overlap-add kernel, just like enframe_kernel, but transposed.
Returns:
tuple: four kernels.
"""
enframed_kernel = th.eye(self.fft_len)[:, None, :]
if support_clp_op:
tmp = fft(th.eye(self.fft_len))
fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
else:
fft_kernel = fft(th.eye(self.fft_len), 1)
if self.mode == "break":
enframed_kernel = th.eye(self.win_len)[:, None, :]
fft_kernel = fft_kernel[: self.win_len]
fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
window = get_window(self.win_type, self.win_len)
self.perfect_reconstruct = check_COLA(window, self.win_len, self.win_len - self.win_hop)
window = th.FloatTensor(window)
if self.mode == "continue":
left_pad = (self.fft_len - self.win_len) // 2
right_pad = left_pad + (self.fft_len - self.win_len) % 2
window = F.pad(window, (left_pad, right_pad))
if self.win_sqrt:
self.padded_window = window
window = th.sqrt(window)
else:
self.padded_window = window**2
fft_kernel = fft_kernel.T * window
ifft_kernel = ifft_kernel * window
ola_kernel = th.eye(self.fft_len)[: self.win_len, None, :]
if self.mode == "continue":
ola_kernel = th.eye(self.fft_len)[:, None, : self.fft_len]
return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
def is_perfect(self):
"""
Whether the parameters win_len, win_hop and win_sqrt
obey constants overlap-add(COLA)
Returns:
bool: Return true if parameters obey COLA.
"""
return self.perfect_reconstruct and self.pad_center
def transform(self, inputs, return_type="complex"):
"""Take input data (audio) to STFT domain.
Args:
inputs (tensor): Tensor of floats, with shape (num_batch, num_samples)
return_type (str, optional): return (mag, phase) when `magphase`,
return (real, imag) when `realimag` and complex(real, imag) when `complex`.
Defaults to 'complex'.
Returns:
tuple: (mag, phase) when `magphase`, return (real, imag) when
`realimag`. Defaults to 'complex', each elements with shape
[num_batch, num_frequencies, num_frames]
"""
assert return_type in ["magphase", "realimag", "complex"]
if inputs.dim() == 2:
inputs = th.unsqueeze(inputs, 1)
self.num_samples = inputs.size(-1)
if self.pad_center:
inputs = F.pad(inputs, (self.pad_amount, self.pad_amount), mode="reflect")
enframe_inputs = F.conv1d(inputs, self.en_k, stride=self.win_hop)
outputs = th.transpose(enframe_inputs, 1, 2)
outputs = F.linear(outputs, self.fft_k)
outputs = th.transpose(outputs, 1, 2)
dim = self.fft_len // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
if return_type == "realimag":
return real, imag
elif return_type == "complex":
assert support_clp_op
return th.complex(real, imag)
else:
mags = th.sqrt(real**2 + imag**2)
phase = th.atan2(imag, real)
return mags, phase
def inverse(self, input1, input2=None, input_type="magphase"):
"""Call the inverse STFT (iSTFT), given tensors produced
by the `transform` function.
Args:
input1 (tensors): Magnitude/Real-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input2 (tensors): Phase/Imag-part of STFT with shape
[num_batch, num_frequencies, num_frames]
input_type (str, optional): Mathematical meaning of input tensor's.
Defaults to 'magphase'.
Returns:
tensors: Reconstructed audio given magnitude and phase. Of
shape [num_batch, num_samples]
"""
assert input_type in ["magphase", "realimag"]
if input_type == "realimag":
real, imag = None, None
if support_clp_op and th.is_complex(input1):
real, imag = input1.real, input1.imag
else:
real, imag = input1, input2
else:
real = input1 * th.cos(input2)
imag = input1 * th.sin(input2)
inputs = th.cat([real, imag], dim=1)
outputs = F.conv_transpose1d(inputs, self.ifft_k, stride=self.win_hop)
t = (self.padded_window[None, :, None]).repeat(1, 1, inputs.size(-1))
t = t.to(inputs.device)
coff = F.conv_transpose1d(t, self.ola_k, stride=self.win_hop)
num_frames = input1.size(-1)
num_samples = num_frames * self.win_hop
rm_start, rm_end = self.pad_amount, self.pad_amount + num_samples
outputs = outputs[..., rm_start:rm_end]
coff = coff[..., rm_start:rm_end]
coffidx = th.where(coff > 1e-8)
outputs[coffidx] = outputs[coffidx] / (coff[coffidx])
return outputs.squeeze(dim=1)
def forward(self, inputs):
"""Take input data (audio) to STFT domain and then back to audio.
Args:
inputs (tensor): Tensor of floats, with shape [num_batch, num_samples]
Returns:
tensor: Reconstructed audio given magnitude and phase.
Of shape [num_batch, num_samples]
"""
mag, phase = self.transform(inputs)
rec_wav = self.inverse(mag, phase)
return rec_wav

View File

@@ -0,0 +1,358 @@
import argparse
import json
import os
import re
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
def split_q_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=1)
return split_v.contiguous()
def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
split_v = split(v, tensor_parallel, rank, dim=0)
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
parser.add_argument("--dtype", type=str, default="float16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--fp8_linear", action="store_true", help="Whether use FP8 for linear layers")
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
args = parser.parse_args()
return args
def convert_timm_dit(args, mapping, dtype="float32"):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
weights = dict()
for name, param in model_params.items():
if name == "input_embed.conv_pos_embed.conv1d.0.weight" or name == "input_embed.conv_pos_embed.conv1d.2.weight":
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype).unsqueeze(-1)
else:
weights[get_trtllm_name(name)] = param.contiguous().to(torch_dtype)
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
scale_factor = math.pow(64, -0.25)
for k, v in weights.items():
if re.match("^transformer_blocks.*.attn.to_k.weight$", k):
weights[k] *= scale_factor
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_k.bias$", k):
weights[k] *= scale_factor
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_q.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_q.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
weights[k] *= scale_factor
elif re.match("^transformer_blocks.*.attn.to_v.weight$", k):
weights[k] = split_q_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_v.bias$", k):
weights[k] = split_q_bias_tp(v, args.num_heads, args.hidden_size, tensor_parallel, mapping.tp_rank)
elif re.match("^transformer_blocks.*.attn.to_out.weight$", k):
weights[k] = split_matrix_tp(v, tensor_parallel, mapping.tp_rank, dim=1)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Weights loaded. Total time: {t}")
return weights
def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
"tp_size": args.tp_size,
"pp_size": args.pp_size,
},
}
if args.fp8_linear:
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)
def covert_and_save(args, rank):
if rank == 0:
save_config(args)
mapping = Mapping(
world_size=args.cp_size * args.tp_size * args.pp_size,
rank=rank,
cp_size=args.cp_size,
tp_size=args.tp_size,
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
def execute(workers, func, args):
if workers == 1:
for rank, f in enumerate(func):
f(args, rank)
else:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert len(exceptions) == 0, "Checkpoint conversion failed, please check error log."
def main():
args = parse_arguments()
world_size = args.cp_size * args.tp_size * args.pp_size
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
return
print("start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Total time of converting checkpoints: {t}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,138 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch.nn as nn
from conv_stft import STFT
from huggingface_hub import hf_hub_download
from vocos import Vocos
opset_version = 17
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--vocoder",
type=str,
default="vocos",
choices=["vocos", "bigvgan"],
help="Vocoder to export",
)
parser.add_argument(
"--output-path",
type=str,
default="./vocos_vocoder.onnx",
help="Output path",
)
return parser.parse_args()
class ISTFTHead(nn.Module):
def __init__(self, n_fft: int, hop_length: int):
super().__init__()
self.out = None
self.stft = STFT(fft_len=n_fft, win_hop=hop_length, win_len=n_fft)
def forward(self, x: torch.Tensor):
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(mag, max=1e2)
real = mag * torch.cos(p)
imag = mag * torch.sin(p)
audio = self.stft.inverse(input1=real, input2=imag, input_type="realimag")
return audio
class VocosVocoder(nn.Module):
def __init__(self, vocos_vocoder):
super(VocosVocoder, self).__init__()
self.vocos_vocoder = vocos_vocoder
istft_head_out = self.vocos_vocoder.head.out
n_fft = self.vocos_vocoder.head.istft.n_fft
hop_length = self.vocos_vocoder.head.istft.hop_length
istft_head_for_export = ISTFTHead(n_fft, hop_length)
istft_head_for_export.out = istft_head_out
self.vocos_vocoder.head = istft_head_for_export
def forward(self, mel):
waveform = self.vocos_vocoder.decode(mel)
return waveform
def export_VocosVocoder(vocos_vocoder, output_path, verbose):
vocos_vocoder = VocosVocoder(vocos_vocoder).cuda()
vocos_vocoder.eval()
dummy_batch_size = 8
dummy_input_length = 500
dummy_mel = torch.randn(dummy_batch_size, 100, dummy_input_length).cuda()
with torch.no_grad():
dummy_waveform = vocos_vocoder(mel=dummy_mel)
print(dummy_waveform.shape)
dummy_input = dummy_mel
torch.onnx.export(
vocos_vocoder,
dummy_input,
output_path,
opset_version=opset_version,
do_constant_folding=True,
input_names=["mel"],
output_names=["waveform"],
dynamic_axes={
"mel": {0: "batch_size", 2: "input_length"},
"waveform": {0: "batch_size", 1: "output_length"},
},
verbose=verbose,
)
print("Exported to {}".format(output_path))
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device="cpu", hf_cache_dir=None):
if vocoder_name == "vocos":
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not supported yet")
vocoder.remove_weight_norm()
vocoder = vocoder.eval().to(device)
return vocoder
if __name__ == "__main__":
args = get_args()
vocoder = load_vocoder(vocoder_name=args.vocoder, device="cpu", hf_cache_dir=None)
if args.vocoder == "vocos":
export_VocosVocoder(vocoder, args.output_path, verbose=False)

View File

@@ -0,0 +1,43 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
ENGINE_PATH=$2
echo "ONNX_PATH: $ONNX_PATH"
echo "ENGINE_PATH: $ENGINE_PATH"
PRECISION="fp32"
MIN_BATCH_SIZE=1
OPT_BATCH_SIZE=1
MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
MEL_MAX_SHAPE="${MAX_BATCH_SIZE}x100x${MAX_INPUT_LENGTH}"
${TRTEXEC} \
--minShapes="mel:${MEL_MIN_SHAPE}" \
--optShapes="mel:${MEL_OPT_SHAPE}" \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}

View File

@@ -0,0 +1,36 @@
#! /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
def main(file_path, substitutions, in_place, participant_ids):
with open(file_path) as f:
pbtxt = Template(f.read())
sub_dict = {"max_queue_size": 0}
sub_dict["participant_ids"] = participant_ids
for sub in substitutions.split(","):
key, value = sub.split(":")
sub_dict[key] = value
pbtxt = pbtxt.safe_substitute(sub_dict)
if in_place:
with open(file_path, "w") as f:
f.write(pbtxt)
else:
print(pbtxt)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("file_path", help="path of the .pbtxt to modify")
parser.add_argument(
"substitutions",
help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
)
parser.add_argument("--in_place", "-i", action="store_true", help="do the operation in-place")
parser.add_argument("--participant_ids", help="Participant IDs for the model", default="")
args = parser.parse_args()
main(**vars(args))

View File

@@ -9,7 +9,7 @@ mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1000000
wanted_max_updates = 1200000
# train params
gpus = 8
@@ -24,7 +24,7 @@ updates_per_epoch = total_hours / mini_batch_hours
# result
epochs = wanted_max_updates / updates_per_epoch
print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
print(f"epochs should be set to: {epochs:.0f} ({epochs / grad_accum:.1f} x gd_acum {grad_accum})")
print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
# print(f" or approx. 0/{steps_per_epoch:.0f} steps")

View File

@@ -1,12 +1,13 @@
import sys
import os
import sys
sys.path.append(os.getcwd())
from f5_tts.model import CFM, DiT
import torch
import thop
import torch
from f5_tts.model import CFM, DiT
""" ~155M """

View File

@@ -0,0 +1,63 @@
import asyncio
import logging
import socket
import time
import numpy as np
import pyaudio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
start_time = time.time()
first_chunk_time = None
async def play_audio_stream():
nonlocal first_chunk_time
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
try:
while True:
data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
if not data:
break
if data == b"END":
logger.info("End of audio received.")
break
audio_array = np.frombuffer(data, dtype=np.float32)
stream.write(audio_array.tobytes())
if first_chunk_time is None:
first_chunk_time = time.time()
finally:
stream.stop_stream()
stream.close()
p.terminate()
logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
try:
data_to_send = f"{text}".encode("utf-8")
await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
await play_audio_stream()
except Exception as e:
logger.error(f"Error in listen_to_F5TTS: {e}")
finally:
client_socket.close()
if __name__ == "__main__":
text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
asyncio.run(listen_to_F5TTS(text_to_send))

View File

@@ -1,7 +1,6 @@
import argparse
import gc
import logging
import numpy as np
import queue
import socket
import struct
@@ -10,19 +9,22 @@ import traceback
import wave
from importlib.resources import files
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.model.backbones.dit import DiT
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -68,7 +70,7 @@ class AudioFileWriterThread(threading.Thread):
class TTSStreamingProcessor:
def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
def __init__(self, model, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
self.device = device or (
"cuda"
if torch.cuda.is_available()
@@ -78,21 +80,24 @@ class TTSStreamingProcessor:
if torch.backends.mps.is_available()
else "cpu"
)
self.mel_spec_type = "vocos"
model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
self.model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
self.model_arc = model_cfg.model.arch
self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
self.sampling_rate = model_cfg.model.mel_spec.target_sample_rate
self.model = self.load_ema_model(ckpt_file, vocab_file, dtype)
self.vocoder = self.load_vocoder_model()
self.sampling_rate = 24000
self.update_reference(ref_audio, ref_text)
self._warm_up()
self.file_writer_thread = None
self.first_package = True
def load_ema_model(self, ckpt_file, vocab_file, dtype):
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cls = DiT
return load_model(
model_cls=model_cls,
model_cfg=model_cfg,
self.model_cls,
self.model_arc,
ckpt_path=ckpt_file,
mel_spec_type=self.mel_spec_type,
vocab_file=vocab_file,
@@ -212,9 +217,14 @@ if __name__ == "__main__":
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", default=9998)
parser.add_argument(
"--model",
default="F5TTS_v1_Base",
help="The model name, e.g. F5TTS_v1_Base",
)
parser.add_argument(
"--ckpt_file",
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_Base/model_1200000.safetensors")),
default=str(hf_hub_download(repo_id="SWivid/F5-TTS", filename="F5TTS_v1_Base/model_1250000.safetensors")),
help="Path to the model checkpoint file",
)
parser.add_argument(
@@ -242,6 +252,7 @@ if __name__ == "__main__":
try:
# Initialize the processor with the model and vocoder
processor = TTSStreamingProcessor(
model=args.model,
ckpt_file=args.ckpt_file,
vocab_file=args.vocab_file,
ref_audio=args.ref_audio,

View File

@@ -40,10 +40,10 @@ Once your datasets are prepared, you can start the training process.
accelerate config
# .yaml files are under src/f5_tts/configs directory
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
accelerate launch src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml
# possible to overwrite accelerate and hydra config
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_Small_train.yaml ++datasets.batch_size_per_gpu=19200
accelerate launch --mixed_precision=fp16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml ++datasets.batch_size_per_gpu=19200
```
### 2. Finetuning practice
@@ -51,9 +51,13 @@ Discussion board for Finetuning [#57](https://github.com/SWivid/F5-TTS/discussio
Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py` see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
The `use_ema = True` is harmful for early-stage finetuned checkpoints (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off and see if provide better results.
If want to finetune with a variant version e.g. *F5TTS_v1_Base_no_zero_init*, manually download pretrained checkpoint from model weight repository and fill in the path correspondingly on web interface.
### 3. Wandb Logging
If use tensorboard as logger, install it first with `pip install tensorboard`.
<ins>The `use_ema = True` might be harmful for early-stage finetuned checkpoints</ins> (which goes just few updates, thus ema weights still dominated by pretrained ones), try turn it off with finetune gradio option or `load_model(..., use_ema=False)`, see if offer better results.
### 3. W&B Logging
The `wandb/` dir will be created under path you run training/finetuning scripts.
@@ -62,7 +66,7 @@ By default, the training script does NOT use logging (assuming you didn't manual
To turn on wandb logging, you can either:
1. Manually login with `wandb login`: Learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/site/ and set the environment variable as follows:
2. Automatically login programmatically by setting an environment variable: Get an API KEY at https://wandb.ai/authorize and set the environment variable as follows:
On Mac & Linux:
@@ -75,7 +79,7 @@ On Windows:
```
set WANDB_API_KEY=<YOUR WANDB API KEY>
```
Moreover, if you couldn't access Wandb and want to log metrics offline, you can the environment variable as follows:
Moreover, if you couldn't access W&B and want to log metrics offline, you can set the environment variable as follows:
```
export WANDB_MODE=offline

View File

@@ -1,12 +1,13 @@
import os
import sys
import signal
import subprocess # For invoking ffprobe
import shutil
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
@@ -16,12 +17,10 @@ from importlib.resources import files
from pathlib import Path
import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
@@ -122,7 +121,7 @@ def prepare_csv_wavs_dir(input_dir, num_workers=None):
for future in tqdm(
chunk_futures,
total=len(chunk),
desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
desc=f"Processing chunk {i // CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1) // CHUNK_SIZE}",
):
try:
result = future.result()
@@ -233,7 +232,7 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
dataset_name = out_dir.stem
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):

View File

@@ -7,20 +7,18 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
repetition_found,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
out_zh = {
@@ -198,7 +196,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if "ZH" in langs:
print(f"Bad zh transcription case: {total_bad_case_zh}")
if "EN" in langs:

View File

@@ -0,0 +1,94 @@
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import json
import os
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
# Define filters for exclusion
out_en = set()
en_filters = ["ا", "", ""]
def process_audio_directory(audio_dir):
sub_result, durations, vocab_set = [], [], set()
bad_case_en = 0
for file in audio_dir.iterdir():
if file.suffix == ".json":
with open(file, "r") as f:
obj = json.load(f)
text = obj["text"]
if any(f in text for f in en_filters) or repetition_found(text, length=4):
bad_case_en += 1
continue
duration = obj["duration"]
audio_file = file.with_suffix(".mp3")
if audio_file.exists():
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result, duration_list, text_vocab_set = [], [], set()
total_bad_case_en = 0
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
dataset_path = Path(dataset_dir)
for sub_dir in dataset_path.iterdir():
if sub_dir.is_dir():
futures.append(executor.submit(process_audio_directory, sub_dir))
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_en = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_en += bad_case_en
executor.shutdown()
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"For {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "char"
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
dataset_name = f"Emilia_EN_{tokenizer}"
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
main()

View File

@@ -1,15 +1,17 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):
@@ -72,7 +74,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":

View File

@@ -1,14 +1,16 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def main():
@@ -50,7 +52,7 @@ def main():
print(f"\nFor {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
if __name__ == "__main__":

View File

@@ -4,15 +4,16 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from tqdm import tqdm
import torchaudio
from datasets import Dataset
from tqdm import tqdm
from f5_tts.model.utils import convert_char_to_pinyin

View File

@@ -1,12 +1,13 @@
import argparse
import os
import shutil
from importlib.resources import files
from cached_path import cached_path
from f5_tts.model import CFM, UNetT, DiT, Trainer
from f5_tts.model.utils import get_tokenizer
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from importlib.resources import files
from f5_tts.model.utils import get_tokenizer
# -------------------------- Dataset Settings --------------------------- #
@@ -20,19 +21,14 @@ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
# -------------------------- Argument Parsing --------------------------- #
def parse_args():
# batch_size_per_gpu = 1000 settting for gpu 8GB
# batch_size_per_gpu = 1600 settting for gpu 12GB
# batch_size_per_gpu = 2000 settting for gpu 16GB
# batch_size_per_gpu = 3200 settting for gpu 24GB
# num_warmup_updates = 300 for 5000 sample about 10 hours
# change save_per_updates , last_per_updates change this value what you need ,
parser = argparse.ArgumentParser(description="Train CFM Model")
parser.add_argument(
"--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
"--exp_name",
type=str,
default="F5TTS_v1_Base",
choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
help="Experiment name",
)
parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
@@ -44,15 +40,15 @@ def parse_args():
parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument("--num_warmup_updates", type=int, default=20000, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=50000, help="Save checkpoint every N updates")
parser.add_argument(
"--keep_last_n_checkpoints",
type=int,
default=-1,
help="-1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints",
)
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
parser.add_argument("--last_per_updates", type=int, default=5000, help="Save last checkpoint every N updates")
parser.add_argument("--finetune", action="store_true", help="Use Finetune")
parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
parser.add_argument(
@@ -69,7 +65,7 @@ def parse_args():
action="store_true",
help="Log inferenced samples per ckpt save updates",
)
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
parser.add_argument("--logger", type=str, default=None, choices=[None, "wandb", "tensorboard"], help="logger")
parser.add_argument(
"--bnb_optimizer",
action="store_true",
@@ -88,19 +84,54 @@ def main():
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
# Model parameters based on experiment name
if args.exp_name == "F5TTS_Base":
if args.exp_name == "F5TTS_v1_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "F5TTS_Base":
wandb_resume_id = None
model_cls = DiT
model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
text_mask_padding=False,
conv_layers=4,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
ckpt_path = args.pretrain
elif args.exp_name == "E2TTS_Base":
wandb_resume_id = None
model_cls = UNetT
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
model_cfg = dict(
dim=1024,
depth=24,
heads=16,
ff_mult=4,
text_mask_padding=False,
pe_attn_head=1,
)
if args.finetune:
if args.pretrain is None:
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
@@ -120,6 +151,7 @@ def main():
print("copy checkpoint for finetune")
# Use the tokenizer and tokenizer_path provided in the command line arguments
tokenizer = args.tokenizer
if tokenizer == "custom":
if not args.tokenizer_path:
@@ -156,7 +188,7 @@ def main():
save_per_updates=args.save_per_updates,
keep_last_n_checkpoints=args.keep_last_n_checkpoints,
checkpoint_path=checkpoint_path,
batch_size=args.batch_size_per_gpu,
batch_size_per_gpu=args.batch_size_per_gpu,
batch_size_type=args.batch_size_type,
max_samples=args.max_samples,
grad_accumulation_steps=args.grad_accumulation_steps,

View File

@@ -1,36 +1,36 @@
import threading
import queue
import re
import gc
import json
import os
import platform
import psutil
import queue
import random
import signal
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from glob import glob
from importlib.resources import files
import click
import gradio as gr
import librosa
import numpy as np
import psutil
import torch
import torchaudio
from cached_path import cached_path
from datasets import Dataset as Dataset_
from datasets.arrow_writer import ArrowWriter
from safetensors.torch import save_file
from safetensors.torch import load_file, save_file
from scipy.io import wavfile
from cached_path import cached_path
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
from f5_tts.infer.utils_infer import transcribe
from importlib.resources import files
from f5_tts.model.utils import convert_char_to_pinyin
training_process = None
@@ -118,26 +118,28 @@ def load_settings(project_name):
# Default settings
default_settings = {
"exp_name": "F5TTS_Base",
"learning_rate": 1e-05,
"batch_size_per_gpu": 1000,
"exp_name": "F5TTS_v1_Base",
"learning_rate": 1e-5,
"batch_size_per_gpu": 3200,
"batch_size_type": "frame",
"max_samples": 64,
"grad_accumulation_steps": 1,
"max_grad_norm": 1,
"max_grad_norm": 1.0,
"epochs": 100,
"num_warmup_updates": 2,
"save_per_updates": 300,
"num_warmup_updates": 100,
"save_per_updates": 500,
"keep_last_n_checkpoints": -1,
"last_per_updates": 100,
"finetune": True,
"file_checkpoint_train": "",
"tokenizer_type": "pinyin",
"tokenizer_file": "",
"mixed_precision": "none",
"logger": "wandb",
"mixed_precision": "fp16",
"logger": "none",
"bnb_optimizer": False,
}
if device == "mps":
default_settings["mixed_precision"] = "none"
# Load settings from file if it exists
if os.path.isfile(file_setting):
@@ -361,27 +363,27 @@ def terminate_process(pid):
def start_training(
dataset_name="",
exp_name="F5TTS_Base",
learning_rate=1e-4,
batch_size_per_gpu=400,
batch_size_type="frame",
max_samples=64,
grad_accumulation_steps=1,
max_grad_norm=1.0,
epochs=11,
num_warmup_updates=200,
save_per_updates=400,
keep_last_n_checkpoints=-1,
last_per_updates=800,
finetune=True,
file_checkpoint_train="",
tokenizer_type="pinyin",
tokenizer_file="",
mixed_precision="fp16",
stream=False,
logger="wandb",
ch_8bit_adam=False,
dataset_name,
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
stream,
logger,
ch_8bit_adam,
):
global training_process, tts_api, stop_signal
@@ -458,7 +460,10 @@ def start_training(
cmd += f" --tokenizer {tokenizer_type}"
cmd += f" --log_samples --logger {logger}"
if logger != "none":
cmd += f" --logger {logger}"
cmd += " --log_samples"
if ch_8bit_adam:
cmd += " --bnb_optimizer"
@@ -515,7 +520,7 @@ def start_training(
training_process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
)
yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)
stdout_queue = queue.Queue()
stderr_queue = queue.Queue()
@@ -584,7 +589,11 @@ def start_training(
gr.update(interactive=True),
)
else:
yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
yield (
"Training complete or paused ...",
gr.update(interactive=False),
gr.update(interactive=True),
)
break
# Small sleep to prevent CPU thrashing
@@ -598,9 +607,9 @@ def start_training(
time.sleep(1)
if training_process is None:
text_info = "train stop"
text_info = "Train stopped !"
else:
text_info = "train complete !"
text_info = "Train complete at end !"
except Exception as e: # Catch all exceptions
# Ensure that we reset the training process variable in case of an error
@@ -615,11 +624,11 @@ def stop_training():
global training_process, stop_signal
if training_process is None:
return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
return "Train not running !", gr.update(interactive=True), gr.update(interactive=False)
terminate_process_tree(training_process.pid)
# training_process = None
stop_signal = True
return "train stop", gr.update(interactive=True), gr.update(interactive=False)
return "Train stopped !", gr.update(interactive=True), gr.update(interactive=False)
def get_list_projects():
@@ -797,14 +806,14 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
print(f"Error processing {file_audio}: {e}")
continue
if duration < 1 or duration > 25:
if duration > 25:
error_files.append([file_audio, "duration > 25 sec"])
if duration < 1 or duration > 30:
if duration > 30:
error_files.append([file_audio, "duration > 30 sec"])
if duration < 1:
error_files.append([file_audio, "duration < 1 sec "])
continue
if len(text) < 3:
error_files.append([file_audio, "very small text len 3"])
error_files.append([file_audio, "very short text length 3"])
continue
text = clear_text(text)
@@ -871,40 +880,37 @@ def check_user(value):
def calculate_train(
name_project,
epochs,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
learning_rate,
num_warmup_updates,
save_per_updates,
last_per_updates,
finetune,
):
path_project = os.path.join(path_data, name_project)
file_duraction = os.path.join(path_project, "duration.json")
file_duration = os.path.join(path_project, "duration.json")
if not os.path.isfile(file_duraction):
hop_length = 256
sampling_rate = 24000
if not os.path.isfile(file_duration):
return (
1000,
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_updates,
"project not found !",
learning_rate,
)
with open(file_duraction, "r") as file:
with open(file_duration, "r") as file:
data = json.load(file)
duration_list = data["duration"]
samples = len(duration_list)
hours = sum(duration_list) / 3600
# if torch.cuda.is_available():
# gpu_properties = torch.cuda.get_device_properties(0)
# total_memory = gpu_properties.total_memory / (1024**3)
# elif torch.backends.mps.is_available():
# total_memory = psutil.virtual_memory().available / (1024**3)
max_sample_length = max(duration_list) * sampling_rate / hop_length
total_samples = len(duration_list)
total_duration = sum(duration_list)
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
@@ -912,64 +918,39 @@ def calculate_train(
for i in range(gpu_count):
gpu_properties = torch.cuda.get_device_properties(i)
total_memory += gpu_properties.total_memory / (1024**3) # in GB
elif torch.xpu.is_available():
gpu_count = torch.xpu.device_count()
total_memory = 0
for i in range(gpu_count):
gpu_properties = torch.xpu.get_device_properties(i)
total_memory += gpu_properties.total_memory / (1024**3)
elif torch.backends.mps.is_available():
gpu_count = 1
total_memory = psutil.virtual_memory().available / (1024**3)
avg_gpu_memory = total_memory / gpu_count
# rough estimate of batch size
if batch_size_type == "frame":
batch = int(total_memory * 0.5)
batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
batch_size_per_gpu = int(38400 / batch)
else:
batch_size_per_gpu = int(total_memory / 8)
batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
batch = batch_size_per_gpu
batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
elif batch_size_type == "sample":
batch_size_per_gpu = int(200 / (total_duration / total_samples))
if batch_size_per_gpu <= 0:
batch_size_per_gpu = 1
if total_samples < 64:
max_samples = int(total_samples * 0.25)
if samples < 64:
max_samples = int(samples * 0.25)
else:
max_samples = 64
num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
num_warmup_updates = int(samples * 0.05)
save_per_updates = int(samples * 0.10)
last_per_updates = int(save_per_updates * 0.25)
# take 1.2M updates as the maximum
max_updates = 1200000
max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
if last_per_updates <= 0:
last_per_updates = 2
if batch_size_type == "frame":
mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
updates_per_epoch = total_duration / mini_batch_duration
elif batch_size_type == "sample":
updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
total_hours = hours
mel_hop_length = 256
mel_sampling_rate = 24000
# target
wanted_max_updates = 1000000
# train params
gpus = gpu_count
frames_per_gpu = batch_size_per_gpu # 8 * 38400 = 307200
grad_accum = 1
# intermediate
mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
# steps_per_epoch = updates_per_epoch * grad_accum
epochs = wanted_max_updates / updates_per_epoch
epochs = int(max_updates / updates_per_epoch)
if finetune:
learning_rate = 1e-5
@@ -977,32 +958,32 @@ def calculate_train(
learning_rate = 7.5e-5
return (
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_updates,
samples,
learning_rate,
int(epochs),
total_samples,
)
def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
try:
checkpoint = torch.load(checkpoint_path, weights_only=True)
print("Original Checkpoint Keys:", checkpoint.keys())
ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
if ema_model_state_dict is None:
return "No 'ema_model_state_dict' found in the checkpoint."
to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
try:
model_state_dict_to_retain = checkpoint[to_retain]
except KeyError:
return f"{to_retain} not found in the checkpoint."
if safetensors:
new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
save_file(ema_model_state_dict, new_checkpoint_path)
save_file(model_state_dict_to_retain, new_checkpoint_path)
else:
new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
torch.save(new_checkpoint, new_checkpoint_path)
return f"New checkpoint saved at: {new_checkpoint_path}"
@@ -1021,7 +1002,11 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
ckpt = torch.load(ckpt_path, map_location="cpu")
if ckpt_path.endswith(".safetensors"):
ckpt = load_file(ckpt_path, device="cpu")
ckpt = {"ema_model_state_dict": ckpt}
elif ckpt_path.endswith(".pt"):
ckpt = torch.load(ckpt_path, map_location="cpu")
ema_sd = ckpt.get("ema_model_state_dict", {})
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
@@ -1039,7 +1024,10 @@ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
torch.save(ckpt, new_ckpt_path)
if new_ckpt_path.endswith(".safetensors"):
save_file(ema_sd, new_ckpt_path)
elif new_ckpt_path.endswith(".pt"):
torch.save(ckpt, new_ckpt_path)
return vocab_new
@@ -1089,9 +1077,11 @@ def vocab_extend(project_name, symbols, model_type):
with open(file_vocab_project, "w", encoding="utf-8") as f:
f.write("\n".join(vocab))
if model_type == "F5-TTS":
if model_type == "F5TTS_v1_Base":
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
elif model_type == "F5TTS_Base":
ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
else:
elif model_type == "E2TTS_Base":
ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
vocab_size_new = len(miss_symbols)
@@ -1101,7 +1091,7 @@ def vocab_extend(project_name, symbols, model_type):
os.makedirs(new_ckpt_path, exist_ok=True)
# Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
@@ -1149,7 +1139,7 @@ def vocab_check(project_name):
info = "You can train using your language !"
else:
vocab_miss = ",".join(miss_symbols)
info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
info = f"The following {len(miss_symbols)} symbols are missing in your language\n\n"
return info, vocab_miss
@@ -1231,21 +1221,24 @@ def infer(
vocab_file = os.path.join(path_data, project, "vocab.txt")
tts_api = F5TTS(
model_type=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
)
print("update >> ", device_test, file_checkpoint, use_ema)
if seed == -1: # -1 used for random
seed = None
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
gen_text=gen_text.lower().strip(),
ref_text=ref_text.lower().strip(),
ref_file=ref_audio,
ref_text=ref_text.lower().strip(),
gen_text=gen_text.lower().strip(),
nfe_step=nfe_step,
file_wave=f.name,
speed=speed,
seed=seed,
remove_silence=remove_silence,
file_wave=f.name,
seed=seed,
)
return f.name, tts_api.device, str(tts_api.seed)
@@ -1404,14 +1397,14 @@ def get_audio_select(file_sample):
with gr.Blocks() as app:
gr.Markdown(
"""
# E2/F5 TTS Automatic Finetune
# F5 TTS Automatic Finetune
This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The checkpoints support English and Chinese.
The pretrained checkpoints support English and Chinese.
For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
"""
@@ -1454,9 +1447,9 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
)
audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
txt_lang = gr.Text(label="Language", value="English")
txt_lang = gr.Textbox(label="Language", value="English")
bt_transcribe = bt_create = gr.Button("Transcribe")
txt_info_transcribe = gr.Text(label="Info", value="")
txt_info_transcribe = gr.Textbox(label="Info", value="")
bt_transcribe.click(
fn=transcribe_all,
inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
@@ -1467,7 +1460,7 @@ Skip this step if you have your dataset, metadata.csv, and a folder wavs with al
random_sample_transcribe = gr.Button("Random Sample")
with gr.Row():
random_text_transcribe = gr.Text(label="Text")
random_text_transcribe = gr.Textbox(label="Text")
random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
random_sample_transcribe.click(
@@ -1482,13 +1475,15 @@ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are incl
```""")
check_button = gr.Button("Check Vocab")
txt_info_check = gr.Text(label="Info", value="")
txt_info_check = gr.Textbox(label="Info", value="")
gr.Markdown("""```plaintext
Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
```""")
exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
exp_name_extend = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
with gr.Row():
txt_extend = gr.Textbox(
@@ -1500,7 +1495,7 @@ Using the extended model, you can finetune to a new language that is missing sym
txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
extend_button = gr.Button("Extend")
txt_info_extend = gr.Text(label="Info", value="")
txt_info_extend = gr.Textbox(label="Info", value="")
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
@@ -1540,8 +1535,8 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
bt_prepare = bt_create = gr.Button("Prepare")
txt_info_prepare = gr.Text(label="Info", value="")
txt_vocab_prepare = gr.Text(label="Vocab", value="")
txt_info_prepare = gr.Textbox(label="Info", value="")
txt_vocab_prepare = gr.Textbox(label="Vocab", value="")
bt_prepare.click(
fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
@@ -1550,61 +1545,73 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
random_sample_prepare = gr.Button("Random Sample")
with gr.Row():
random_text_prepare = gr.Text(label="Tokenizer")
random_text_prepare = gr.Textbox(label="Tokenizer")
random_audio_prepare = gr.Audio(label="Audio", type="filepath")
random_sample_prepare.click(
fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
)
with gr.TabItem("Train Data"):
with gr.TabItem("Train Model"):
gr.Markdown("""```plaintext
The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
The auto-setting is still experimental. Set a large value of epoch if not sure; and keep last N checkpoints if limited disk space.
If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
```""")
with gr.Row():
bt_calculate = bt_create = gr.Button("Auto Settings")
exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
tokenizer_file = gr.Textbox(label="Tokenizer File")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune")
lb_samples = gr.Label(label="Samples")
batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
bt_calculate = bt_create = gr.Button("Auto Settings")
with gr.Row():
ch_finetune = bt_create = gr.Checkbox(label="Finetune", value=True)
tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint", value="")
epochs = gr.Number(label="Epochs")
learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
max_grad_norm = gr.Number(label="Max Gradient Norm")
num_warmup_updates = gr.Number(label="Warmup Updates")
with gr.Row():
exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
batch_size_type = gr.Radio(
label="Batch Size Type",
choices=["frame", "sample"],
info="frame is calculated as seconds * sampling_rate / hop_length",
)
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
grad_accumulation_steps = gr.Number(
label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
)
max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")
with gr.Row():
batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
max_samples = gr.Number(label="Max Samples", value=64)
with gr.Row():
grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
with gr.Row():
epochs = gr.Number(label="Epochs", value=10)
num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
with gr.Row():
save_per_updates = gr.Number(label="Save per Updates", value=300)
save_per_updates = gr.Number(
label="Save per Updates",
info="Save intermediate checkpoints every N updates",
minimum=10,
)
keep_last_n_checkpoints = gr.Number(
label="Keep Last N Checkpoints",
value=-1,
step=1,
precision=0,
info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
info="-1 to keep all, 0 to not save intermediate, > 0 to keep last N",
minimum=-1,
)
last_per_updates = gr.Number(label="Last per Updates", value=100)
last_per_updates = gr.Number(
label="Last per Updates",
info="Save latest checkpoint with suffix _last.pt every N updates",
minimum=10,
)
gr.Radio(label="") # placeholder
with gr.Row():
ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
with gr.Column():
start_button = gr.Button("Start Training")
stop_button = gr.Button("Stop Training", interactive=False)
if projects_selelect is not None:
(
@@ -1651,7 +1658,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
ch_8bit_adam.value = bnb_optimizer_value
ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
txt_info_train = gr.Text(label="Info", value="")
txt_info_train = gr.Textbox(label="Info", value="")
list_audios, select_audio = get_audio_project(projects_selelect, False)
@@ -1718,23 +1725,21 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
fn=calculate_train,
inputs=[
cm_project,
epochs,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
learning_rate,
num_warmup_updates,
save_per_updates,
last_per_updates,
ch_finetune,
],
outputs=[
epochs,
learning_rate,
batch_size_per_gpu,
max_samples,
num_warmup_updates,
save_per_updates,
last_per_updates,
lb_samples,
learning_rate,
epochs,
],
)
@@ -1744,25 +1749,25 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
def setup_load_settings():
output_components = [
exp_name, # 1
learning_rate, # 2
batch_size_per_gpu, # 3
batch_size_type, # 4
max_samples, # 5
grad_accumulation_steps, # 6
max_grad_norm, # 7
epochs, # 8
num_warmup_updates, # 9
save_per_updates, # 10
keep_last_n_checkpoints, # 11
last_per_updates, # 12
ch_finetune, # 13
file_checkpoint_train, # 14
tokenizer_type, # 15
tokenizer_file, # 16
mixed_precision, # 17
cd_logger, # 18
ch_8bit_adam, # 19
exp_name,
learning_rate,
batch_size_per_gpu,
batch_size_type,
max_samples,
grad_accumulation_steps,
max_grad_norm,
epochs,
num_warmup_updates,
save_per_updates,
keep_last_n_checkpoints,
last_per_updates,
ch_finetune,
file_checkpoint_train,
tokenizer_type,
tokenizer_file,
mixed_precision,
cd_logger,
ch_8bit_adam,
]
return output_components
@@ -1782,19 +1787,23 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
with gr.TabItem("Test Model"):
gr.Markdown("""```plaintext
SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
```""")
exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
exp_name = gr.Radio(
label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
)
list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
with gr.Row():
nfe_step = gr.Number(label="NFE Step", value=32)
speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
seed = gr.Number(label="Seed", value=-1, minimum=-1)
seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
remove_silence = gr.Checkbox(label="Remove Silence")
ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
with gr.Row():
ch_use_ema = gr.Checkbox(
label="Use EMA", value=True, info="Turn off at early stage might offer better results"
)
cm_checkpoint = gr.Dropdown(
choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
)
@@ -1802,20 +1811,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
random_sample_infer = gr.Button("Random Sample")
ref_text = gr.Textbox(label="Ref Text")
ref_audio = gr.Audio(label="Audio Ref", type="filepath")
gen_text = gr.Textbox(label="Gen Text")
ref_text = gr.Textbox(label="Reference Text")
ref_audio = gr.Audio(label="Reference Audio", type="filepath")
gen_text = gr.Textbox(label="Text to Generate")
random_sample_infer.click(
fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
)
with gr.Row():
txt_info_gpu = gr.Textbox("", label="Device")
seed_info = gr.Text(label="Seed :")
check_button_infer = gr.Button("Infer")
txt_info_gpu = gr.Textbox("", label="Inference on Device :")
seed_info = gr.Textbox(label="Used Random Seed :")
check_button_infer = gr.Button("Inference")
gen_audio = gr.Audio(label="Audio Gen", type="filepath")
gen_audio = gr.Audio(label="Generated Audio", type="filepath")
check_button_infer.click(
fn=infer,
@@ -1838,18 +1847,20 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
with gr.TabItem("Reduce Checkpoint"):
with gr.TabItem("Prune Checkpoint"):
gr.Markdown("""```plaintext
Reduce the model size from 5GB to 1.3GB. The new checkpoint can be used for inference or fine-tuning afterward, but it cannot be used to continue training.
Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
```""")
txt_path_checkpoint = gr.Text(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Text(label="Path to Output:")
ch_safetensors = gr.Checkbox(label="Safetensors", value="")
txt_info_reduse = gr.Text(label="Info", value="")
reduse_button = gr.Button("Reduce")
txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
with gr.Row():
ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
ch_safetensors = gr.Checkbox(label="Save with safetensors format", value=True)
txt_info_reduse = gr.Textbox(label="Info", value="")
reduse_button = gr.Button("Prune")
reduse_button.click(
fn=extract_and_save_ema_model,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
fn=prune_checkpoint,
inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
outputs=[txt_info_reduse],
)

View File

@@ -4,70 +4,71 @@ import os
from importlib.resources import files
import hydra
from omegaconf import OmegaConf
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model import CFM, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
tokenizer = cfg.model.tokenizer
mel_spec_type = cfg.model.mel_spec.mel_spec_type
exp_name = f"{cfg.model.name}_{mel_spec_type}_{cfg.model.tokenizer}_{cfg.datasets.name}"
def main(model_cfg):
model_cls = hydra.utils.get_class(f"f5_tts.model.{model_cfg.model.backbone}")
model_arc = model_cfg.model.arch
tokenizer = model_cfg.model.tokenizer
mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
exp_name = f"{model_cfg.model.name}_{mel_spec_type}_{model_cfg.model.tokenizer}_{model_cfg.datasets.name}"
wandb_resume_id = None
# set text tokenizer
if tokenizer != "custom":
tokenizer_path = cfg.datasets.name
tokenizer_path = model_cfg.datasets.name
else:
tokenizer_path = cfg.model.tokenizer_path
tokenizer_path = model_cfg.model.tokenizer_path
vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
# set model
if "F5TTS" in cfg.model.name:
model_cls = DiT
elif "E2TTS" in cfg.model.name:
model_cls = UNetT
wandb_resume_id = None
model = CFM(
transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=cfg.model.mel_spec,
transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=model_cfg.model.mel_spec.n_mel_channels),
mel_spec_kwargs=model_cfg.model.mel_spec,
vocab_char_map=vocab_char_map,
)
# init trainer
trainer = Trainer(
model,
epochs=cfg.optim.epochs,
learning_rate=cfg.optim.learning_rate,
num_warmup_updates=cfg.optim.num_warmup_updates,
save_per_updates=cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", -1),
checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
batch_size=cfg.datasets.batch_size_per_gpu,
batch_size_type=cfg.datasets.batch_size_type,
max_samples=cfg.datasets.max_samples,
grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
max_grad_norm=cfg.optim.max_grad_norm,
logger=cfg.ckpts.logger,
epochs=model_cfg.optim.epochs,
learning_rate=model_cfg.optim.learning_rate,
num_warmup_updates=model_cfg.optim.num_warmup_updates,
save_per_updates=model_cfg.ckpts.save_per_updates,
keep_last_n_checkpoints=model_cfg.ckpts.keep_last_n_checkpoints,
checkpoint_path=str(files("f5_tts").joinpath(f"../../{model_cfg.ckpts.save_dir}")),
batch_size_per_gpu=model_cfg.datasets.batch_size_per_gpu,
batch_size_type=model_cfg.datasets.batch_size_type,
max_samples=model_cfg.datasets.max_samples,
grad_accumulation_steps=model_cfg.optim.grad_accumulation_steps,
max_grad_norm=model_cfg.optim.max_grad_norm,
logger=model_cfg.ckpts.logger,
wandb_project="CFM-TTS",
wandb_run_name=exp_name,
wandb_resume_id=wandb_resume_id,
last_per_updates=cfg.ckpts.last_per_updates,
log_samples=True,
bnb_optimizer=cfg.optim.bnb_optimizer,
last_per_updates=model_cfg.ckpts.last_per_updates,
log_samples=model_cfg.ckpts.log_samples,
bnb_optimizer=model_cfg.optim.bnb_optimizer,
mel_spec_type=mel_spec_type,
is_local_vocoder=cfg.model.vocoder.is_local,
local_vocoder_path=cfg.model.vocoder.local_path,
is_local_vocoder=model_cfg.model.vocoder.is_local,
local_vocoder_path=model_cfg.model.vocoder.local_path,
model_cfg_dict=OmegaConf.to_container(model_cfg, resolve=True),
)
train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)
trainer.train(
train_dataset,
num_workers=cfg.datasets.num_workers,
num_workers=model_cfg.datasets.num_workers,
resumable_with_seed=666, # seed for shuffling dataset
)