81 Commits

Author SHA1 Message Date
Yushen CHEN
f2a4f8581f Update runtime README 2025-10-22 08:37:32 +08:00
SWivid
a17c5ae435 pytorch imple.: fix batch 1 inference from last commit 2025-10-22 00:31:56 +00:00
SWivid
a0b8fb5df2 runtime trtllm: minor fixes. pytorch: update text_embedding logic to correct v0 batching. 2025-10-22 00:19:45 +00:00
SWivid
c8bfc3aa3d runtime trtllm: support v1 and custom 2025-10-21 22:02:25 +00:00
SWivid
8d3ec72159 runtime trtllm: clean-up v0 code, several fixes. 2025-10-20 10:30:58 +00:00
SWivid
65ada48a62 set attn related default value for unet-t backbone: #1192 2025-10-09 06:51:25 +00:00
SWivid
77d3ec623b v1.1.9 2025-09-13 13:42:33 +08:00
SWivid
186799d6dc remove numpy<=1.26.4 for python_version>=3.11 #1162; update links 2025-09-13 13:40:55 +08:00
Yushen CHEN
31bb78f2ab Update badge links 2025-09-03 15:12:24 +08:00
SWivid
e61824009a v1.1.8 2025-08-28 12:33:37 +00:00
SWivid
06a74910bd add option for text embedding late average upsampling 2025-08-28 11:46:11 +00:00
Yushen CHEN
ac3c43595c delete .github/workflows/sync-hf.yaml for online space stability 2025-08-27 06:52:18 +08:00
Jim
605fa13b42 Fix raw.arrow missing rows (#1145)
* fix raw.arrow missing rows

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-07-22 19:38:44 +08:00
Yushen CHEN
5f35f27230 update pyproject.toml 2025-07-15 17:28:41 +08:00
Yushen CHEN
c96c3aeed8 Update pyproject.toml 2025-07-14 14:36:26 +08:00
Yushen CHEN
9b60fe6a34 update pyproject.toml, set gradio<=5.35.0 until fix #1126 2025-07-14 14:29:19 +08:00
SWivid
a275798a2f last fix patch-1 2025-07-08 18:44:47 +08:00
SWivid
efc7a7498b fix #1111 #1037 remove redundant unwrap_model for AcceleratedOptimizer; which has no attribute '_modules' thus conflict with has_compiled_regions check introduced in accelerate v1.7.0 2025-07-08 18:39:43 +08:00
SWivid
9842314127 update slicer in finetune_gradio, legacy min_length 2s changed to 20s 2025-07-08 16:59:46 +08:00
SWivid
69b0e0110e v1.1.6 fla support, several changes for finetune and infer-cli 2025-07-03 00:08:42 +08:00
SWivid
52c84776e5 fine-grained speed control for infer-cli. #1112 2025-07-02 23:41:55 +08:00
Danh Tran
ebbd7bd91f Update WAV File Naming and Dependencies 📝🔊 (#1091)
* Update infer_cli.py

* Update pyproject.toml

* formalized

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-24 23:23:00 +08:00
Yushen CHEN
ac42286d04 update finetune_gradio.py, not to force lower case
Do not force lower case; otherwise the training/inference text mismatches the main inference code
2025-06-23 16:37:51 +08:00
Yushen CHEN
d937efa6f3 fix finetune_gradio.py, not to force lower case 2025-06-23 16:22:33 +08:00
Yushen CHEN
8975fca803 Merge pull request #1084 from starkwj/main
Speedup inference by batching CFG in DiT
2025-06-12 03:54:04 +08:00
SWivid
8b0053ad0c backward compatibility 2025-06-12 03:52:12 +08:00
SWivid
b3ef4ed1d7 correct imple., minor fixes 2025-06-12 03:32:19 +08:00
starkwj
b1a9438496 Batch cfg DiT forward 2025-06-11 09:03:30 +00:00
Zhikang Niu
0914170e98 Add flash_attn2 support attn_mask, minor fixes (#1066)
* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug of get the length of ref_mel

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-11 12:14:32 +08:00
SWivid
c6ebad0220 switch sync-hf workflow logic on release, avoid hidden space error with pypi/local_editable mismatch 2025-06-06 07:23:54 +08:00
SWivid
cfaba6387f refresh hf-space first 2025-06-06 07:22:02 +08:00
SWivid
646f34b20f v1.1.5 pypi 2025-06-06 07:08:59 +08:00
Jerrister Zheng
2e2acc6ea2 Update: Empirically Pruned Step Sampling (#1077)
* update Empirically Pruned Step Sampling

---------

Co-authored-by: Fast-F5-TTS <2942755472@qq.com>
Co-authored-by: SWivid <swivid@qq.com>
2025-06-04 22:59:30 +08:00
SWivid
6fbe7592f5 rebase default sample_rate to 24khz for runtime 2025-06-04 11:22:31 +08:00
Alice Yanagi
7e37bc5d9a Fix the duration computation in triton_trtllm/client_grpc.py (#1071)
* Update client_grpc.py

Using `actual_duration` to compute metrics like RTF.
2025-06-04 11:18:00 +08:00
SWivid
35f130ee85 minor update for infer-gradio 2025-06-04 06:11:49 +08:00
SWivid
e6469f705f update shared.md 2025-06-03 22:09:13 +08:00
SWivid
31cd818095 formatting 2025-06-03 21:23:47 +08:00
Yushen CHEN
1d13664b24 Merge pull request #1063 from ionite34/dev
Fix finetune training with spaces in file paths
2025-06-03 21:18:41 +08:00
Yushen CHEN
b27471ea06 Merge pull request #1072 from hvoss-techfak/main
German Model support
2025-06-03 21:18:25 +08:00
Hendric Voss
8fb55f107e Update SHARED.md 2025-06-03 14:08:30 +02:00
Hendric Voss
ccb380b752 Added German Model 2025-06-03 14:08:03 +02:00
Ionite
3027b43953 Fix training with file path spaces 2025-05-28 15:24:35 -04:00
SWivid
ecd1c3949a Add py312 check for tempfile delete_on_close keyword 2025-05-22 23:10:29 +08:00
SWivid
2968aa184f v1.1.5 several fixes 2025-05-22 17:41:10 +08:00
SWivid
fb26b6d93e Fix #1046 tempfile related bug 2025-05-22 17:40:14 +08:00
SWivid
f7f266cdd9 preprocess only once. Fix #1043 2025-05-21 02:26:05 +08:00
SWivid
695c735737 Exclude broken dependency version with accelerate 2025-05-16 17:48:41 +08:00
SWivid
3e2a07da1d Update README.md & minor fixes 2025-05-11 19:40:37 +08:00
SWivid
c47687487c minor fix for vocab check in finetune_gradio 2025-05-05 23:32:00 +08:00
SWivid
ac79d0ec1e v1.1.4 2025-05-05 04:05:25 +08:00
SWivid
dad398c0c1 Bug Fix #1015
Ensure custom config hashable in
2025-05-05 03:55:05 +08:00
SWivid
3d969bf78d minor fix for backward compatibility to gradio multistyle feature 2025-05-05 02:07:19 +08:00
SWivid
7c741c05f9 v1.1.3 better infer_gradio with cherrypick and cache support 2025-05-05 01:42:41 +08:00
SWivid
6d1a1e886a formatting, sorting 2025-05-05 01:41:28 +08:00
SWivid
b4efcd836a Add cache feature. Retrieve previous generated segments, default cache size 100 2025-05-05 01:37:22 +08:00
SWivid
818b868fab Update infer_gradio.py. Enable seed selecting for multistyle generation 2025-05-05 00:58:24 +08:00
SWivid
e6fee5e9ba Update infer_gradio.py
Use gr.Column to ensure backward compatibility

Remove height attr from gr.File to avoid possible malposition across versions
2025-05-04 09:25:41 +08:00
Yushen CHEN
2de214c122 Merge pull request #1014 from fakerybakery/fix-gradio-app-250503
Fix Gradio app
2025-05-04 09:14:32 +08:00
mrfakename
2999f642ce Row -> Column 2025-05-03 17:59:07 -07:00
mrfakename
03cff73343 remove equal_height requirement
Seems to break Gradio demo.
2025-05-03 17:57:41 -07:00
mrfakename
63c513840d fix gradio app 2025-05-03 17:56:21 -07:00
SWivid
3e6b6c0c0c update infer_gradio.py. rename for consistency 2025-05-04 08:04:00 +08:00
SWivid
f00ac4d06b fix infer-gradio chat feature etc. 2025-05-04 08:00:16 +08:00
Yushen CHEN
b0658bfd24 Merge pull request #1013 from petermg/main
Update infer_gradio.py
2025-05-04 03:33:22 +08:00
petermg
0cae51d646 Update infer_gradio.py
Modified formatting
2025-05-03 12:07:58 -07:00
petermg
95976041f2 Update infer_gradio.py
Added "randomize seed" checkmark and option to specify seed showing last seed used and can manually enter the desired seed number.
2025-05-03 11:38:50 -07:00
petermg
ba1bf74215 Update infer_gradio.py
Modified it so that when you upload a text file, the text of that file will show in the text input window. Also made the text file upload window show up BELOW the text input display window.
2025-05-03 11:22:07 -07:00
petermg
536c29ac57 Update infer_gradio.py
Modified the UI to accept txt files as inputs
2025-05-02 12:45:39 -07:00
SWivid
c4c61b0110 v1.1.2 several updates
add data prepare script recipe for emilia-yodas; fix speech_edit.py; fix tensorrt-llm server code-switch
2025-05-02 03:13:33 +08:00
SWivid
5f80fec160 fix speech_edit.py 2025-04-26 02:10:39 +08:00
Yushen CHEN
178cb8afe6 Merge pull request #986 from fakerybakery/emilia-v2
Add processing script for new Emilia dataset format
2025-04-19 14:16:37 +08:00
mrfakename
761c7ed938 Add processing script for new Emilia dataset format 2025-04-18 20:56:31 -07:00
Yushen CHEN
13fd6f8e07 Merge pull request #971 from tbxark-fork/main
chore: Update the model checkpoint path to use the cache path.
2025-04-14 15:54:50 +08:00
tbxark
b2284b6cff chore: Update the model checkpoint path to use the cache path. 2025-04-14 11:28:48 +08:00
SWivid
4b4359bc39 finetune_gradio not to use fp16 by default for mps device 2025-04-03 22:33:21 +08:00
SWivid
fe5c562212 v1.1.1 add benchmark and trtllm offline code 2025-04-03 18:33:48 +08:00
Yushen CHEN
2374f8ec39 Merge pull request #948 from yuekaizhang/trtllm_benchmark
[TRT-LLM] add benchmark code
2025-04-03 18:27:21 +08:00
Yuekai Zhang
f4f10bff6c fix comment 2025-04-03 02:44:59 -07:00
Yuekai Zhang
9771ec6a3a add benchmark code 2025-04-03 02:42:40 -07:00
SWivid
4b3cd13382 Update README.md 2025-04-03 15:04:42 +08:00
60 changed files with 2070 additions and 960 deletions

View File

@@ -1,18 +0,0 @@
name: Sync to HF Space
on:
push:
branches:
- main
jobs:
trigger_curl:
runs-on: ubuntu-latest
steps:
- name: Send cURL POST request
run: |
curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
-s \
-H "Content-Type: application/json" \
-d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"

View File

@@ -3,11 +3,14 @@ repos:
# Ruff version.
rev: v0.11.2
hooks:
# Run the linter.
- id: ruff
name: ruff linter
args: [--fix]
# Run the formatter.
- id: ruff-format
name: ruff formatter
- id: ruff
name: ruff sorter
args: [--select, I, --fix]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:

View File

@@ -2,11 +2,12 @@
[![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
[![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
[![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
[![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
[![lab](https://img.shields.io/badge/Peng%20Cheng-Lab-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
[![demo](https://img.shields.io/badge/GitHub-Demo-orange.svg)](https://swivid.github.io/F5-TTS/)
[![hfspace](https://img.shields.io/badge/🤗-HF%20Space-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[![msspace](https://img.shields.io/badge/🤖-MS%20Space-blue)](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS)
[![lab](https://img.shields.io/badge/🏫-X--LANCE-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
[![lab](https://img.shields.io/badge/🏫-SII-grey?labelColor=lightgrey)](https://www.sii.edu.cn/)
[![lab](https://img.shields.io/badge/🏫-PCL-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
**F5-TTS**: Diffusion Transformer with ConvNeXt V2, with faster training and inference.
@@ -26,8 +27,8 @@
### Create a separate environment if needed
```bash
# Create a python 3.10 conda env (you could also use virtualenv)
conda create -n f5-tts python=3.10
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
conda create -n f5-tts python=3.11
conda activate f5-tts
```
@@ -91,7 +92,7 @@ conda activate f5-tts
> ```bash
> git clone https://github.com/SWivid/F5-TTS.git
> cd F5-TTS
> # git submodule update --init --recursive # (optional, if need > bigvgan)
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
> pip install -e .
> ```
@@ -107,6 +108,21 @@ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,targ
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
```
### Runtime
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
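For readers new to the RTF column above: real-time factor is generation time divided by the duration of the generated audio, so smaller is faster and values below 1.0 mean faster than real time. A minimal sketch of the computation (the `synthesize` callable and variable names are illustrative assumptions, not the benchmark script's actual API; the repo's gRPC client computes RTF from the actual audio duration):
```python
import time

# Hypothetical helper: `synthesize` stands in for whatever client call returns
# the waveform as a 1-D sample array; it is not an actual F5-TTS API.
def measure_rtf(synthesize, text, sample_rate=24000):
    start = time.perf_counter()
    wav = synthesize(text)
    elapsed = time.perf_counter() - start
    audio_duration = len(wav) / sample_rate
    return elapsed / audio_duration  # e.g. 0.0394 => ~39.4 ms of compute per second of audio
```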
## Inference
@@ -179,19 +195,6 @@ f5-tts_infer-cli -c custom.toml
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
```
### 3. Runtime
Deployment solution with Triton and TensorRT-LLM.
#### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs.
| Model | Concurrency | Avg Latency | RTF |
|-------|-------------|----------------|-------|
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
## Training

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.0"
version = "1.1.9"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -15,17 +15,17 @@ classifiers = [
]
dependencies = [
"accelerate>=0.33.0",
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
"bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'",
"cached_path",
"click",
"datasets",
"ema_pytorch>=0.5.2",
"gradio>=3.45.2",
"gradio>=5.0.0",
"hydra-core>=1.3.0",
"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4",
"numpy<=1.26.4; python_version<='3.10'",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
@@ -38,6 +38,7 @@ dependencies = [
"tqdm>=4.65.0",
"transformers",
"transformers_stream_generator",
"unidecode",
"vocos",
"wandb",
"x_transformers>=1.31.14",

View File

@@ -6,5 +6,5 @@ target-version = "py310"
dummy-variable-rgx = "^_.*$"
[lint.isort]
force-single-line = true
force-single-line = false
lines-after-imports = 2

View File

@@ -9,13 +9,13 @@ from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
)
from f5_tts.model.utils import seed_everything
@@ -154,8 +154,8 @@ if __name__ == "__main__":
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
ref_text="Some call me nature, others call me mother nature.",
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -32,6 +32,8 @@ model:
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -4,6 +4,7 @@
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

View File

@@ -1,6 +1,7 @@
import os
import sys
sys.path.append(os.getcwd())
import argparse
@@ -23,6 +24,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
from f5_tts.model import CFM
from f5_tts.model.utils import get_tokenizer
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
@@ -146,10 +148,15 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_librispeech_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))

View File

@@ -5,17 +5,16 @@ import json
import os
import sys
sys.path.append(os.getcwd())
import multiprocessing as mp
from importlib.resources import files
import numpy as np
from f5_tts.eval.utils_eval import (
get_seed_tts_test,
run_asr_wer,
run_sim,
)
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
rel_path = str(files("f5_tts").joinpath("../../"))

View File

@@ -126,8 +126,13 @@ def get_inference_prompt(
else:
text_list = text
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
ref_mel_len = ref_mel.shape[-1]
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, (

View File

@@ -13,7 +13,7 @@ To avoid possible inference failures, make sure you have seen through the follow
- Add some spaces (blank: " ") or punctuation marks (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise it is not treated as a sentence boundary when chunking.
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
- Try <ins>turning off `use_ema` if using an early-stage</ins> finetuned checkpoint (one trained for only a few updates).
@@ -129,6 +129,28 @@ ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice; refer to `src/f5_tts/infer/examples/multi/story.txt`.
## API Usage
```python
from importlib.resources import files
from f5_tts.api import F5TTS
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
```
Check [api.py](../api.py) for more details.
## TensorRT-LLM Deployment
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
## Socket Real-time Service
Real-time voice output with chunk stream:
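The diff excerpt cuts off before the repo's own snippet; as a rough illustration of what chunk streaming looks like on the client side, here is a minimal sketch that assumes a server pushing raw float32 PCM chunks over TCP on localhost:9998 (the port and wire format are assumptions, not the documented protocol; see the repo's socket server code for the real one):
```python
import socket

import numpy as np

# Assumed endpoint and wire format: the server streams raw float32 PCM chunks
# over TCP and closes the connection when synthesis is done.
HOST, PORT = "127.0.0.1", 9998

with socket.create_connection((HOST, PORT)) as sock:
    sock.sendall("Hello, this is a streaming synthesis request.".encode("utf-8"))
    buf = bytearray()
    while True:
        chunk = sock.recv(4096)
        if not chunk:
            break
        buf.extend(chunk)

# Assumes the byte count is a whole number of float32 samples.
audio = np.frombuffer(bytes(buf), dtype=np.float32)
print(f"received {audio.size} samples")
```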

View File

@@ -22,6 +22,8 @@
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
- [French](#french)
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
- [German](#german)
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
- [Hindi](#hindi)
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
- [Italian](#italian)
@@ -97,6 +99,22 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
## German
#### F5-TTS Base @ de @ hvoss-techfak
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
```bash
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
## Hindi
#### F5-TTS Small @ hi @ SPRINGLab

View File

@@ -13,8 +13,8 @@ output_file = "infer_cli_story.wav"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
speed = 0.8 # will ignore global speed
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""

View File

@@ -1 +1 @@
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land. [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] Goodbye, [main] said he, [country] “Im off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."

View File

@@ -12,22 +12,23 @@ import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from unidecode import unidecode
from f5_tts.infer.utils_infer import (
mel_spec_type,
target_rms,
cross_fade_duration,
nfe_step,
cfg_strength,
sway_sampling_coef,
speed,
fix_duration,
cross_fade_duration,
device,
fix_duration,
infer_process,
load_model,
load_vocoder,
mel_spec_type,
nfe_step,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
speed,
sway_sampling_coef,
target_rms,
)
@@ -112,6 +113,11 @@ parser.add_argument(
action="store_true",
help="To save each audio chunks during inference",
)
parser.add_argument(
"--no_legacy_text",
action="store_false",
help="Not to use lossy ASCII transliterations of unicode text in saved file names.",
)
parser.add_argument(
"--remove_silence",
action="store_true",
@@ -197,6 +203,12 @@ output_file = args.output_file or config.get(
)
save_chunk = args.save_chunk or config.get("save_chunk", False)
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
if save_chunk and use_legacy_text:
print(
"\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
)
remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
@@ -321,9 +333,10 @@ def main():
text = re.sub(reg2, "", text)
ref_audio_ = voices[voice]["ref_audio"]
ref_text_ = voices[voice]["ref_text"]
local_speed = voices[voice].get("speed", speed)
gen_text_ = text.strip()
print(f"Voice: {voice}")
audio_segment, final_sample_rate, spectragram = infer_process(
audio_segment, final_sample_rate, spectrogram = infer_process(
ref_audio_,
ref_text_,
gen_text_,
@@ -335,7 +348,7 @@ def main():
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
speed=local_speed,
fix_duration=fix_duration,
device=device,
)
@@ -344,6 +357,8 @@ def main():
if save_chunk:
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
if use_legacy_text:
gen_text_ = unidecode(gen_text_)
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,
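The chunk-saving branch above now runs the text through `unidecode` before it lands in a `.wav` file name; a standalone sketch of that transliteration idea (the sample strings are illustrative):
```python
from unidecode import unidecode

# Lossy ASCII transliteration, as used for legacy-friendly chunk file names.
samples = ["Grüße aus Köln", "你好，世界", "Café déjà vu"]
for text in samples:
    print(f"{text!r} -> {unidecode(text)!r}")
# e.g. 'Grüße aus Köln' becomes 'Grusse aus Koln' (umlauts and CJK are flattened to ASCII).
```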

View File

@@ -3,9 +3,11 @@
import gc
import json
import os
import re
import tempfile
from collections import OrderedDict
from functools import lru_cache
from importlib.resources import files
import click
@@ -17,6 +19,7 @@ import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
import spaces
@@ -32,15 +35,16 @@ def gpu_decorator(func):
return func
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
load_vocoder,
load_model,
preprocess_ref_audio_text,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
tempfile_kwargs,
)
from f5_tts.model import DiT, UNetT
DEFAULT_TTS_MODEL = "F5-TTS_v1"
@@ -78,6 +82,8 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
vocab_path = str(cached_path(vocab_path))
if model_cfg is None:
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
elif isinstance(model_cfg, str):
model_cfg = json.loads(model_cfg)
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
@@ -90,7 +96,7 @@ chat_tokenizer_state = None
@gpu_decorator
def generate_response(messages, model, tokenizer):
def chat_model_inference(messages, model, tokenizer):
"""Generate response using Qwen"""
text = tokenizer.apply_chat_template(
messages,
@@ -112,6 +118,17 @@ def generate_response(messages, model, tokenizer):
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
@gpu_decorator
def load_text_from_file(file):
if file:
with open(file, "r", encoding="utf-8") as f:
text = f.read().strip()
else:
text = ""
return gr.update(value=text)
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
@gpu_decorator
def infer(
ref_audio_orig,
@@ -119,6 +136,7 @@ def infer(
gen_text,
model,
remove_silence,
seed,
cross_fade_duration=0.15,
nfe_step=32,
speed=1,
@@ -128,8 +146,15 @@ def infer(
gr.Warning("Please provide reference audio.")
return gr.update(), gr.update(), ref_text
# Set inference seed
if seed < 0 or seed > 2**31 - 1:
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
seed = np.random.randint(0, 2**31 - 1)
torch.manual_seed(seed)
used_seed = seed
if not gen_text.strip():
gr.Warning("Please enter text to generate.")
gr.Warning("Please enter text to generate or upload a text file.")
return gr.update(), gr.update(), ref_text
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
@@ -142,7 +167,7 @@ def infer(
show_info("Loading E2-TTS model...")
E2TTS_ema_model = load_e2tts()
ema_model = E2TTS_ema_model
elif isinstance(model, list) and model[0] == "Custom":
elif isinstance(model, tuple) and model[0] == "Custom":
assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
global custom_ema_model, pre_custom_path
if pre_custom_path != model[1]:
@@ -166,44 +191,59 @@ def infer(
# Remove silence
if remove_silence:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
sf.write(f.name, final_wave, final_sample_rate)
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
try:
sf.write(temp_path, final_wave, final_sample_rate)
remove_silence_for_generated_wav(f.name)
final_wave, _ = torchaudio.load(f.name)
finally:
os.unlink(temp_path)
final_wave = final_wave.squeeze().cpu().numpy()
# Save the spectrogram
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)
save_spectrogram(combined_spectrogram, spectrogram_path)
return (final_sample_rate, final_wave), spectrogram_path, ref_text
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
with gr.Row():
gen_text_input = gr.Textbox(
label="Text to Generate",
lines=10,
max_lines=40,
scale=4,
)
gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
generate_btn = gr.Button("Synthesize", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
ref_text_input = gr.Textbox(
label="Reference Text",
info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
lines=2,
)
remove_silence = gr.Checkbox(
label="Remove Silences",
info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
value=False,
)
with gr.Row():
ref_text_input = gr.Textbox(
label="Reference Text",
info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
lines=2,
scale=4,
)
ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
with gr.Row():
randomize_seed = gr.Checkbox(
label="Randomize Seed",
info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
value=True,
scale=3,
)
seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
with gr.Column(scale=4):
remove_silence = gr.Checkbox(
label="Remove Silences",
info="If undesired long silence(s) produced, turn on to automatically detect and crop.",
value=False,
)
speed_slider = gr.Slider(
label="Speed",
minimum=0.3,
@@ -238,21 +278,45 @@ with gr.Blocks() as app_tts:
ref_text_input,
gen_text_input,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
):
audio_out, spectrogram_path, ref_text_out = infer(
if randomize_seed:
seed_input = np.random.randint(0, 2**31 - 1)
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
ref_audio_input,
ref_text_input,
gen_text_input,
tts_model_choice,
remove_silence,
seed=seed_input,
cross_fade_duration=cross_fade_duration_slider,
nfe_step=nfe_slider,
speed=speed_slider,
)
return audio_out, spectrogram_path, ref_text_out
return audio_out, spectrogram_path, ref_text_out, used_seed
gen_text_file.upload(
load_text_from_file,
inputs=[gen_text_file],
outputs=[gen_text_input],
)
ref_text_file.upload(
load_text_from_file,
inputs=[ref_text_file],
outputs=[ref_text_input],
)
ref_audio_input.clear(
lambda: [None, None],
None,
[ref_text_input, ref_text_file],
)
generate_btn.click(
basic_tts,
@@ -261,35 +325,46 @@ with gr.Blocks() as app_tts:
ref_text_input,
gen_text_input,
remove_silence,
randomize_seed,
seed_input,
cross_fade_duration_slider,
nfe_slider,
speed_slider,
],
outputs=[audio_output, spectrogram_output, ref_text_input],
outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
)
def parse_speechtypes_text(gen_text):
# Pattern to find {speechtype}
pattern = r"\{(.*?)\}"
# Pattern to find {str} or {"name": str, "seed": int, "speed": float}
pattern = r"(\{.*?\})"
# Split the text by the pattern
tokens = re.split(pattern, gen_text)
segments = []
current_style = "Regular"
current_type_dict = {
"name": "Regular",
"seed": -1,
"speed": 1.0,
}
for i in range(len(tokens)):
if i % 2 == 0:
# This is text
text = tokens[i].strip()
if text:
segments.append({"style": current_style, "text": text})
current_type_dict["text"] = text
segments.append(current_type_dict)
else:
# This is style
style = tokens[i].strip()
current_style = style
# This is type
type_str = tokens[i].strip()
try: # if type dict
current_type_dict = json.loads(type_str)
except json.decoder.JSONDecodeError:
type_str = type_str[1:-1] # remove brace {}
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
return segments
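To see what the new pattern does, a tiny standalone check (the input string is illustrative): splitting with a capture group keeps the `{...}` markers as their own list elements, which is what lets the loop above alternate between type markers and text.
```python
import re

text = '{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello there. {Sad} I really wanted a sandwich...'
tokens = re.split(r"(\{.*?\})", text)
# The capture group keeps the delimiters: even indices are plain text,
# odd indices are the {...} markers (either a JSON dict or a bare name).
print(tokens)
```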
@@ -300,44 +375,55 @@ with gr.Blocks() as app_multistyle:
"""
# Multiple Speech-Type Generation
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
"""
)
with gr.Row():
gr.Markdown(
"""
**Example Input:**
{Regular} Hello, I'd like to order a sandwich please.
{Surprised} What do you mean you're out of bread?
{Sad} I really wanted a sandwich though...
{Angry} You know what, darn you and your little shop!
{Whisper} I'll just go back home and cry now.
{Shouting} Why me?!
**Example Input:** <br>
{Regular} Hello, I'd like to order a sandwich please. <br>
{Surprised} What do you mean you're out of bread? <br>
{Sad} I really wanted a sandwich though... <br>
{Angry} You know what, darn you and your little shop! <br>
{Whisper} I'll just go back home and cry now. <br>
{Shouting} Why me?!
"""
)
gr.Markdown(
"""
**Example Input 2:**
{Speaker1_Happy} Hello, I'd like to order a sandwich please.
{Speaker2_Regular} Sorry, we're out of bread.
{Speaker1_Sad} I really wanted a sandwich though...
{Speaker2_Whisper} I'll give you the last one I was hiding.
**Example Input 2:** <br>
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
"""
)
gr.Markdown(
"Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
)
# Regular speech type (mandatory)
with gr.Row() as regular_row:
with gr.Column():
with gr.Row(variant="compact") as regular_row:
with gr.Column(scale=1, min_width=160):
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
regular_insert = gr.Button("Insert Label", variant="secondary")
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
with gr.Column(scale=3):
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
with gr.Column(scale=3):
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
with gr.Row():
regular_seed_slider = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
)
regular_speed_slider = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
# Regular speech type (max 100)
max_speech_types = 100
@@ -345,25 +431,55 @@ with gr.Blocks() as app_multistyle:
speech_type_names = [regular_name]
speech_type_audios = [regular_audio]
speech_type_ref_texts = [regular_ref_text]
speech_type_ref_text_files = [regular_ref_text_file]
speech_type_seeds = [regular_seed_slider]
speech_type_speeds = [regular_speed_slider]
speech_type_delete_btns = [None]
speech_type_insert_btns = [regular_insert]
# Additional speech types (99 more)
for i in range(max_speech_types - 1):
with gr.Row(visible=False) as row:
with gr.Column():
with gr.Row(variant="compact", visible=False) as row:
with gr.Column(scale=1, min_width=160):
name_input = gr.Textbox(label="Speech Type Name")
delete_btn = gr.Button("Delete Type", variant="secondary")
insert_btn = gr.Button("Insert Label", variant="secondary")
audio_input = gr.Audio(label="Reference Audio", type="filepath")
ref_text_input = gr.Textbox(label="Reference Text", lines=2)
delete_btn = gr.Button("Delete Type", variant="stop")
with gr.Column(scale=3):
audio_input = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column(scale=3):
ref_text_input = gr.Textbox(label="Reference Text", lines=4)
with gr.Row():
seed_input = gr.Slider(
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
)
speed_input = gr.Slider(
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
)
with gr.Column(scale=1, min_width=160):
ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
speech_type_rows.append(row)
speech_type_names.append(name_input)
speech_type_audios.append(audio_input)
speech_type_ref_texts.append(ref_text_input)
speech_type_ref_text_files.append(ref_text_file_input)
speech_type_seeds.append(seed_input)
speech_type_speeds.append(speed_input)
speech_type_delete_btns.append(delete_btn)
speech_type_insert_btns.append(insert_btn)
# Global logic for all speech types
for i in range(max_speech_types):
speech_type_audios[i].clear(
lambda: [None, None],
None,
[speech_type_ref_texts[i], speech_type_ref_text_files[i]],
)
speech_type_ref_text_files[i].upload(
load_text_from_file,
inputs=[speech_type_ref_text_files[i]],
outputs=[speech_type_ref_texts[i]],
)
# Button to add speech type
add_speech_type_btn = gr.Button("Add Speech Type")
@@ -385,27 +501,44 @@ with gr.Blocks() as app_multistyle:
# Function to delete a speech type
def delete_speech_type_fn():
return gr.update(visible=False), None, None, None
return gr.update(visible=False), None, None, None, None
# Update delete button clicks
# Update delete button clicks and ref text file changes
for i in range(1, len(speech_type_delete_btns)):
speech_type_delete_btns[i].click(
delete_speech_type_fn,
outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
outputs=[
speech_type_rows[i],
speech_type_names[i],
speech_type_audios[i],
speech_type_ref_texts[i],
speech_type_ref_text_files[i],
],
)
# Text input for the prompt
gen_text_input_multistyle = gr.Textbox(
label="Text to Generate",
lines=10,
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
)
with gr.Row():
gen_text_input_multistyle = gr.Textbox(
label="Text to Generate",
lines=10,
max_lines=40,
scale=4,
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
)
gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
def make_insert_speech_type_fn(index):
def insert_speech_type_fn(current_text, speech_type_name):
def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
current_text = current_text or ""
speech_type_name = speech_type_name or "None"
updated_text = current_text + f"{{{speech_type_name}}} "
if not speech_type_name:
gr.Warning("Please enter speech type name before insert.")
return current_text
speech_type_dict = {
"name": speech_type_name,
"seed": speech_type_seed,
"speed": speech_type_speed,
}
updated_text = current_text + json.dumps(speech_type_dict) + " "
return updated_text
return insert_speech_type_fn
@@ -414,15 +547,24 @@ with gr.Blocks() as app_multistyle:
insert_fn = make_insert_speech_type_fn(i)
insert_btn.click(
insert_fn,
inputs=[gen_text_input_multistyle, speech_type_names[i]],
inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
outputs=gen_text_input_multistyle,
)
with gr.Accordion("Advanced Settings", open=False):
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
value=True,
)
with gr.Accordion("Advanced Settings", open=True):
with gr.Row():
with gr.Column():
show_cherrypick_multistyle = gr.Checkbox(
label="Show Cherry-pick Interface",
info="Turn on to show interface, picking seeds from previous generations.",
value=False,
)
with gr.Column():
remove_silence_multistyle = gr.Checkbox(
label="Remove Silences",
info="Turn on to automatically detect and crop long silences.",
value=True,
)
# Generate button
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
@@ -430,6 +572,30 @@ with gr.Blocks() as app_multistyle:
# Output audio
audio_output_multistyle = gr.Audio(label="Synthesized Audio")
# Used seed gallery
cherrypick_interface_multistyle = gr.Textbox(
label="Cherry-pick Interface",
lines=10,
max_lines=40,
show_copy_button=True,
interactive=False,
visible=False,
)
# Logic control to show/hide the cherrypick interface
show_cherrypick_multistyle.change(
lambda is_visible: gr.update(visible=is_visible),
show_cherrypick_multistyle,
cherrypick_interface_multistyle,
)
# Function to load text to generate from file
gen_text_file_multistyle.upload(
load_text_from_file,
inputs=[gen_text_file_multistyle],
outputs=[gen_text_input_multistyle],
)
@gpu_decorator
def generate_multistyle_speech(
gen_text,
@@ -457,41 +623,60 @@ with gr.Blocks() as app_multistyle:
# For each segment, generate speech
generated_audio_segments = []
current_style = "Regular"
current_type_name = "Regular"
inference_meta_data = ""
for segment in segments:
style = segment["style"]
name = segment["name"]
seed_input = segment["seed"]
speed = segment["speed"]
text = segment["text"]
if style in speech_types:
current_style = style
if name in speech_types:
current_type_name = name
else:
gr.Warning(f"Type {style} is not available, will use Regular as default.")
current_style = "Regular"
gr.Warning(f"Type {name} is not available, will use Regular as default.")
current_type_name = "Regular"
try:
ref_audio = speech_types[current_style]["audio"]
ref_audio = speech_types[current_type_name]["audio"]
except KeyError:
gr.Warning(f"Please provide reference audio for type {current_style}.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
ref_text = speech_types[current_style].get("ref_text", "")
gr.Warning(f"Please provide reference audio for type {current_type_name}.")
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
ref_text = speech_types[current_type_name].get("ref_text", "")
# Generate speech for this segment
audio_out, _, ref_text_out = infer(
ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
) # show_info=print no pull to top when generating
if seed_input == -1:
seed_input = np.random.randint(0, 2**31 - 1)
# Generate or retrieve speech for this segment
audio_out, _, ref_text_out, used_seed = infer(
ref_audio,
ref_text,
text,
tts_model_choice,
remove_silence,
seed=seed_input,
cross_fade_duration=0,
speed=speed,
show_info=print, # no pull to top when generating
)
sr, audio_data = audio_out
generated_audio_segments.append(audio_data)
speech_types[current_style]["ref_text"] = ref_text_out
speech_types[current_type_name]["ref_text"] = ref_text_out
inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
# Concatenate all audio segments
if generated_audio_segments:
final_audio_data = np.concatenate(generated_audio_segments)
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
return (
[(sr, final_audio_data)]
+ [speech_types[name]["ref_text"] for name in speech_types]
+ [inference_meta_data]
)
else:
gr.Warning("No audio generated.")
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
generate_multistyle_btn.click(
generate_multistyle_speech,
@@ -504,7 +689,7 @@ with gr.Blocks() as app_multistyle:
+ [
remove_silence_multistyle,
],
outputs=[audio_output_multistyle] + speech_type_ref_texts,
outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
)
# Validation function to disable Generate button if speech types are missing
@@ -521,7 +706,7 @@ with gr.Blocks() as app_multistyle:
# Parse the gen_text to get the speech types used
segments = parse_speechtypes_text(gen_text)
speech_types_in_text = set(segment["style"] for segment in segments)
speech_types_in_text = set(segment["name"] for segment in segments)
# Check if all speech types in text are available
missing_speech_types = speech_types_in_text - speech_types_available
@@ -544,10 +729,10 @@ with gr.Blocks() as app_chat:
gr.Markdown(
"""
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
2. Load the chat model.
3. Record your message through your microphone.
3. Record your message through your microphone or type it.
4. The AI will respond using the reference voice.
"""
)
@@ -603,22 +788,35 @@ Have a conversation with an AI using your reference voice!
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
with gr.Column():
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
scale=3,
)
ref_text_file_chat = gr.File(
label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
)
with gr.Row():
randomize_seed_chat = gr.Checkbox(
label="Randomize Seed",
value=True,
info="Uncheck to use the seed specified.",
scale=3,
)
seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
remove_silence_chat = gr.Checkbox(
label="Remove Silences",
value=True,
)
ref_text_chat = gr.Textbox(
label="Reference Text",
info="Optional: Leave blank to auto-transcribe",
lines=2,
)
system_prompt_chat = gr.Textbox(
label="System Prompt",
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
lines=2,
)
chatbot_interface = gr.Chatbot(label="Conversation")
chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
with gr.Row():
with gr.Column():
@@ -635,140 +833,119 @@ Have a conversation with an AI using your reference voice!
send_btn_chat = gr.Button("Send Message")
clear_btn_chat = gr.Button("Clear Conversation")
conversation_state = gr.State(
value=[
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
)
# Modify process_audio_input to use model and tokenizer from state
# Modify process_audio_input to generate user input
@gpu_decorator
def process_audio_input(audio_path, text, history, conv_state):
def process_audio_input(conv_state, audio_path, text):
"""Handle audio or text input from user"""
if not audio_path and not text.strip():
return history, conv_state, ""
return conv_state
if audio_path:
text = preprocess_ref_audio_text(audio_path, text)[1]
if not text.strip():
return history, conv_state, ""
return conv_state
conv_state.append({"role": "user", "content": text})
history.append((text, None))
return conv_state
response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
# Use model and tokenizer from state to get text response
@gpu_decorator
def generate_text_response(conv_state, system_prompt):
"""Generate text response from AI"""
system_prompt_state = [{"role": "system", "content": system_prompt}]
response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
conv_state.append({"role": "assistant", "content": response})
history[-1] = (text, response)
return history, conv_state, ""
return conv_state
@gpu_decorator
def generate_audio_response(history, ref_audio, ref_text, remove_silence):
def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
"""Generate TTS audio for AI response"""
if not history or not ref_audio:
return None
if not conv_state or not ref_audio:
return None, ref_text, seed_input
last_user_message, last_ai_response = history[-1]
if not last_ai_response:
return None
last_ai_response = conv_state[-1]["content"]
if not last_ai_response or conv_state[-1]["role"] != "assistant":
return None, ref_text, seed_input
audio_result, _, ref_text_out = infer(
if randomize_seed:
seed_input = np.random.randint(0, 2**31 - 1)
audio_result, _, ref_text_out, used_seed = infer(
ref_audio,
ref_text,
last_ai_response,
tts_model_choice,
remove_silence,
seed=seed_input,
cross_fade_duration=0.15,
speed=1.0,
show_info=print, # show_info=print no pull to top when generating
)
return audio_result, ref_text_out
return audio_result, ref_text_out, used_seed
def clear_conversation():
"""Reset the conversation"""
return [], [
{
"role": "system",
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
}
]
return [], None
def update_system_prompt(new_prompt):
"""Update the system prompt and reset the conversation"""
new_conv_state = [{"role": "system", "content": new_prompt}]
return [], new_conv_state
# Handle audio input
audio_input_chat.stop_recording(
process_audio_input,
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
audio_input_chat,
ref_text_file_chat.upload(
load_text_from_file,
inputs=[ref_text_file_chat],
outputs=[ref_text_chat],
)
# Handle text input
text_input_chat.submit(
process_audio_input,
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
text_input_chat,
)
for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
user_operation(
process_audio_input,
inputs=[chatbot_interface, audio_input_chat, text_input_chat],
outputs=[chatbot_interface],
).then(
generate_text_response,
inputs=[chatbot_interface, system_prompt_chat],
outputs=[chatbot_interface],
).then(
generate_audio_response,
inputs=[
chatbot_interface,
ref_audio_chat,
ref_text_chat,
remove_silence_chat,
randomize_seed_chat,
seed_input_chat,
],
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
).then(
lambda: [None, None],
None,
[audio_input_chat, text_input_chat],
)
# Handle send button
send_btn_chat.click(
process_audio_input,
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
outputs=[chatbot_interface, conversation_state],
).then(
generate_audio_response,
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
outputs=[audio_output_chat, ref_text_chat],
).then(
lambda: None,
None,
text_input_chat,
)
# Handle clear button or system prompt change and reset conversation
for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
user_operation(
clear_conversation,
outputs=[chatbot_interface, audio_output_chat],
)
# Handle clear button
clear_btn_chat.click(
clear_conversation,
outputs=[chatbot_interface, conversation_state],
)
# Handle system prompt change and reset conversation
system_prompt_chat.change(
update_system_prompt,
inputs=system_prompt_chat,
outputs=[chatbot_interface, conversation_state],
)
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app:
gr.Markdown(
f"""
# E2/F5 TTS
# F5-TTS Demo Space
This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
@@ -798,7 +975,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
global tts_model_choice
if new_choice == "Custom": # override in case webpage is refreshed
custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
return (
gr.update(visible=True, value=custom_ckpt_path),
gr.update(visible=True, value=custom_vocab_path),
@@ -810,7 +987,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
global tts_model_choice
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
with open(last_used_custom, "w", encoding="utf-8") as f:
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")

View File

@@ -1,5 +1,6 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
from importlib.resources import files
@@ -7,6 +8,7 @@ from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
@@ -14,6 +16,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
from f5_tts.model import CFM
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
device = (
"cuda"
if torch.cuda.is_available()
@@ -55,7 +58,8 @@ win_length = model_cfg.model.mel_spec.win_length
n_fft = model_cfg.model.mel_spec.n_fft
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
output_dir = "tests"
@@ -152,7 +156,7 @@ for part in parts_to_edit:
dim=-1,
)
offset = end * target_sample_rate
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)

View File

@@ -4,6 +4,7 @@ import os
import sys
from concurrent.futures import ThreadPoolExecutor
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
@@ -14,6 +15,7 @@ from importlib.resources import files
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
@@ -27,12 +29,11 @@ from transformers import pipeline
from vocos import Vocos
from f5_tts.model import CFM
from f5_tts.model.utils import (
get_tokenizer,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}
_ref_text_cache = {}
device = (
"cuda"
@@ -44,6 +45,8 @@ device = (
else "cpu"
)
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
# -----------------------------------------
target_sample_rate = 24000
@@ -290,62 +293,74 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
# Compute a hash of the reference audio file
with open(ref_audio_orig, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
show_info("Using cached preprocessed reference audio...")
ref_audio = _ref_audio_cache[audio_hash]
else: # first pass, do preprocess
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
aseg = AudioSegment.from_file(ref_audio_orig)
if clip_short:
# 1. try to find long silence for clipping
# 1. try to find long silence for clipping
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
aseg = non_silent_wave
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")
ref_audio = f.name
aseg.export(temp_path, format="wav")
ref_audio = temp_path
# Compute a hash of the reference audio file
with open(ref_audio, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
# Cache the processed reference audio
_ref_audio_cache[audio_hash] = ref_audio
if not ref_text.strip():
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
global _ref_text_cache
if audio_hash in _ref_text_cache:
# Use cached asr transcription
show_info("Using cached reference text...")
ref_text = _ref_audio_cache[audio_hash]
ref_text = _ref_text_cache[audio_hash]
else:
show_info("No reference text provided, transcribing reference audio...")
ref_text = transcribe(ref_audio)
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
_ref_audio_cache[audio_hash] = ref_text
_ref_text_cache[audio_hash] = ref_text
else:
show_info("Using custom reference text...")
@@ -384,7 +399,7 @@ def infer_process(
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)

View File

@@ -1,9 +1,7 @@
from f5_tts.model.cfm import CFM
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT
from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.cfm import CFM
from f5_tts.model.trainer import Trainer

View File

@@ -10,19 +10,18 @@ d - dimension
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNorm_Final,
TimestepEmbedding,
precompute_freqs_cis,
get_pos_embed_indices,
)
@@ -30,11 +29,16 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
def __init__(
self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
self.average_upsampling = average_upsampling # zipvoice-style text late average upsampling (after text encoder)
if average_upsampling:
assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"
if conv_layers > 0:
self.extra_modeling = True
@@ -46,11 +50,46 @@ class TextEmbedding(nn.Module):
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
batch, text_len, text_dim = text.shape
if audio_mask is None:
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
valid_mask = audio_mask & text_mask
audio_lens = audio_mask.sum(dim=1) # [batch]
valid_lens = valid_mask.sum(dim=1) # [batch]
upsampled_text = torch.zeros_like(text)
for i in range(batch):
audio_len = audio_lens[i].item()
valid_len = valid_lens[i].item()
if valid_len == 0:
continue
valid_ind = torch.where(valid_mask[i])[0]
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
base_repeat = audio_len // valid_len
remainder = audio_len % valid_len
indices = []
for j in range(valid_len):
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
indices.extend([j] * repeat_count)
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
upsampled = valid_data[indices] # [audio_len, text_dim]
upsampled_text[i, :audio_len, :] = upsampled
return upsampled_text
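For intuition, a minimal sketch (toy numbers only, standalone code rather than the module's API) of how `average_upsample_text_by_mask` distributes audio frames over the valid text embeddings:
```python
# Illustrative only: distribute audio_len = 10 frames over valid_len = 3 text
# embeddings, mirroring the base_repeat/remainder logic above.
audio_len, valid_len = 10, 3
base_repeat, remainder = audio_len // valid_len, audio_len % valid_len  # 3, 1
indices = []
for j in range(valid_len):
    # later tokens absorb the remainder, matching `j >= valid_len - remainder`
    repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
    indices.extend([j] * repeat_count)
print(indices)  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
```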
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
if self.mask_padding:
text_mask = text == 0
@@ -62,10 +101,7 @@ class TextEmbedding(nn.Module):
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
text = text + self.freqs_cis[:seq_len, :]
# convnextv2 blocks
if self.mask_padding:
@@ -76,6 +112,9 @@ class TextEmbedding(nn.Module):
else:
text = self.text_blocks(text)
if self.average_upsampling:
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
return text
@@ -114,9 +153,12 @@ class DiT(nn.Module):
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
text_embedding_average_upsampling=False,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -126,7 +168,11 @@ class DiT(nn.Module):
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
text_num_embeds,
text_dim,
mask_padding=text_mask_padding,
average_upsampling=text_embedding_average_upsampling,
conv_layers=conv_layers,
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
@@ -146,6 +192,8 @@ class DiT(nn.Module):
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for _ in range(depth)
]
@@ -179,6 +227,48 @@ class DiT(nn.Module):
return ckpt_forward
def get_input_embed(
self,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
audio_mask: bool["b n"] | None = None, # noqa: F722
):
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
else:
batch = x.shape[0]
seq_lens = audio_mask.sum(dim=1)
text_embed_list = []
for i in range(batch):
text_embed_i = self.text_embed(
text[i].unsqueeze(0),
seq_lens[i].item(),
drop_text=drop_text,
audio_mask=audio_mask,
)
text_embed_list.append(text_embed_i[0])
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
if cache:
if drop_text:
self.text_uncond = text_embed
else:
self.text_cond = text_embed
if cache:
if drop_text:
text_embed = self.text_uncond
else:
text_embed = self.text_cond
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
@@ -188,10 +278,11 @@ class DiT(nn.Module):
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@@ -199,18 +290,20 @@ class DiT(nn.Module):
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(
x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
)
x_uncond = self.get_input_embed(
x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
x = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
)
rope = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@@ -11,16 +11,15 @@ from __future__ import annotations
import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNorm_Final,
ConvPositionEmbedding,
MMDiTBlock,
AdaLayerNorm_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -142,26 +141,15 @@ class MMDiT(nn.Module):
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cache:
if drop_text:
if self.text_uncond is None:
@@ -175,6 +163,41 @@ class MMDiT(nn.Module):
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
return x, c
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
c = torch.cat((c_cond, c_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x, c = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
)
seq_len = x.shape[1]
text_len = text.shape[1]
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@@ -8,24 +8,24 @@ d - dimension
"""
from __future__ import annotations
from typing import Literal
import torch
from torch import nn
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
ConvPositionEmbedding,
Attention,
AttnProcessor,
ConvNeXtV2Block,
ConvPositionEmbedding,
FeedForward,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
@@ -120,6 +120,8 @@ class UNetT(nn.Module):
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
skip_connect_type: Literal["add", "concat", "none"] = "concat",
):
super().__init__()
@@ -150,7 +152,11 @@ class UNetT(nn.Module):
attn_norm = RMSNorm(dim)
attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -178,26 +184,16 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
@@ -209,8 +205,41 @@ class UNetT(nn.Module):
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
# postfix time t to input x, [b n d] -> [b n+1 d]
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
if mask is not None:

View File

@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
get_epss_timesteps,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
@@ -92,6 +93,7 @@ class CFM(nn.Module):
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
use_epss=True,
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
@@ -160,16 +162,31 @@ class CFM(nn.Module):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
)
# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
)
return pred
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
cfg_infer=True,
cache=True,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
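A minimal sketch of the packed cond/uncond classifier-free guidance step (toy tensors standing in for the transformer output; only the split-and-combine math matches the code above):
```python
# Illustrative only: split a packed [2b, n, d] prediction and apply CFG.
import torch

def cfg_combine(pred_cfg: torch.Tensor, cfg_strength: float) -> torch.Tensor:
    pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)  # cond half, uncond half
    return pred + (pred - null_pred) * cfg_strength

pred_cfg = torch.randn(4, 8, 16)  # batch of 2, packed cond+uncond -> 4
print(cfg_combine(pred_cfg, cfg_strength=2.0).shape)  # torch.Size([2, 8, 16])
```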
# noise input
@@ -190,7 +207,10 @@ class CFM(nn.Module):
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
else:
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
@@ -232,10 +252,9 @@ class CFM(nn.Module):
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
if not exists(lens): # if lens not acquired by trainer from collate_fn
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
mask = lens_to_mask(lens, length=seq_len)
# get a random span to mask out for training conditionally
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
@@ -270,10 +289,9 @@ class CFM(nn.Module):
else:
drop_text = False
# to rigorously mask out padding, record lengths in collate_fn in dataset.py and pass them in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
# applying the mask uses more memory; may need to adjust batch size or the batch sampler's long-sequence threshold
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
)
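For reference, a sketch of what the padding mask passed as `mask=mask` looks like; `lens_to_mask_sketch` below is an assumed re-implementation for illustration, not the repo's `lens_to_mask`:
```python
# Assumed semantics (illustrative): lengths -> boolean padding mask [b, n].
import torch

def lens_to_mask_sketch(lens: torch.Tensor, length: int) -> torch.Tensor:
    # True for valid frames (position < per-sample length), False for padding
    return torch.arange(length, device=lens.device)[None, :] < lens[:, None]

print(lens_to_mask_sketch(torch.tensor([3, 5]), length=5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```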
# flow matching loss

View File

@@ -312,7 +312,7 @@ def collate_fn(batch):
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
return dict(
mel=mel_specs,
mel_lengths=mel_lengths,
mel_lengths=mel_lengths, # records for padding mask
text=text,
text_lengths=text_lengths,
)

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
from __future__ import annotations
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b n d"] = None, # context c
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
# Attention processor
if is_package_available("flash_attn"):
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.pe_attn_head = pe_attn_head
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
if self.attn_backend == "torch":
# mask: e.g. at inference a batch may have different target durations, so mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
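A self-contained sketch (toy shapes, plain SDPA call rather than the repo's `Attention`/`AttnProcessor`) of the padding-mask expansion used in the torch backend above:
```python
# Illustrative only: expand a [b, n] padding mask for scaled_dot_product_attention.
import torch
import torch.nn.functional as F

b, h, n, d = 2, 4, 6, 8
q = k = v = torch.randn(b, h, n, d)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0, 0]], dtype=torch.bool)  # True = keep frame
attn_mask = mask[:, None, None, :].expand(b, h, n, n)        # 'b n -> b h n n'
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```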
@@ -514,9 +552,9 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b nt d"] = None, # context c, here text
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
def __init__(
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="torch", # "torch" or "flash_attn"
attn_mask_enabled=True,
):
super().__init__()
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
def forward(self, timestep: float["b"]):
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d

View File

@@ -19,6 +19,7 @@ from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists
# trainer
@@ -148,7 +149,7 @@ class Trainer:
if self.is_main:
checkpoint = dict(
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
optimizer_state_dict=self.optimizer.state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
update=update,
@@ -241,7 +242,7 @@ class Trainer:
del checkpoint["model_state_dict"][key]
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
update = checkpoint["update"]

View File

@@ -5,11 +5,10 @@ import random
from collections import defaultdict
from importlib.resources import files
import torch
from torch.nn.utils.rnn import pad_sequence
import jieba
from pypinyin import lazy_pinyin, Style
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything
@@ -36,6 +35,16 @@ def default(v, d):
return v if exists(v) else d
def is_package_available(package_name: str) -> bool:
try:
import importlib
package_exists = importlib.util.find_spec(package_name) is not None
return package_exists
except Exception:
return False
# tensor helpers
@@ -190,3 +199,22 @@ def repetition_found(text, length=2, tolerance=10):
if count > tolerance:
return True
return False
# get the empirically pruned step for sampling
def get_epss_timesteps(n, device, dtype):
dt = 1 / 32
predefined_timesteps = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = predefined_timesteps.get(n, [])
if not t:
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
return dt * torch.tensor(t, device=device, dtype=dtype)
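A quick check of the EPSS schedule above (illustrative; assumes `get_epss_timesteps` is importable from `f5_tts.model.utils`, matching the import added in the `cfm.py` hunk earlier):
```python
# Illustrative: EPSS concentrates steps early; unknown n falls back to a uniform grid.
import torch
from f5_tts.model.utils import get_epss_timesteps

t = get_epss_timesteps(7, device="cpu", dtype=torch.float32)
print(t)  # tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.5000, 0.7500, 1.0000])

t_uniform = get_epss_timesteps(9, device="cpu", dtype=torch.float32)  # no entry for 9
print(torch.allclose(t_uniform, torch.linspace(0, 1, 10)))  # True
```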

View File

@@ -0,0 +1,3 @@
# runtime/triton_trtllm related
model.cache
model_repo/

View File

@@ -1,47 +1,79 @@
## Triton Inference Serving Best Practice for F5-TTS
### Quick Start
Directly launch the service using docker compose.
### Setup
#### Option 1: Quick Start
```sh
# TODO: support F5TTS_v1_Base
MODEL=F5TTS_Base docker compose up
# Directly launch the service using docker compose
MODEL=F5TTS_v1_Base docker compose up
```
### Build Image
Build the docker image from scratch.
#### Option 2: Build from scratch
```sh
# Build the docker image
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
```
### Create Docker Container
```sh
# Create Docker Container
your_mount_dir=/mnt:/mnt
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
```
### Export Models to TensorRT-LLM and Launch Server
Inside the docker container, follow the official TensorRT-LLM guide to build the qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
### Build TensorRT-LLM Engines and Launch Server
Inside the docker container, follow the official TensorRT-LLM guide to build the qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
```sh
bash run.sh 0 4 F5TTS_Base
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
bash run.sh 0 4 F5TTS_v1_Base
```
> [!NOTE]
> If using a custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
> Remember to use a matched model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
>
> If using a checkpoint with a different structure, see `scripts/convert_checkpoint.py` and modify it if necessary.
> [!IMPORTANT]
> If training or finetuning with fp32, add the `--dtype float32` flag when converting the checkpoint in `run.sh` phase 1.
### HTTP Client
```sh
python3 client_http.py
```
### Benchmark using Dataset
### Benchmarking
#### Using Client-Server Mode
```sh
# bash run.sh 5 5 F5TTS_v1_Base
num_task=2
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
```
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
#### Using Offline TRT-LLM Mode
```sh
# bash run.sh 7 7 F5TTS_v1_Base
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
```
| Model | Concurrency | Avg Latency | RTF |
|-------|-------------|----------------|-------|
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
### Benchmark Results
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
| Model | Concurrency | Avg Latency | RTF | Mode |
|---------------------|----------------|-------------|--------|-----------------|
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
### Credits
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
1. [Yuekai Zhang](https://github.com/yuekaizhang)
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)

View File

@@ -0,0 +1,473 @@
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
# 2025 (authors: Yuekai Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
""" Example Usage
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $CKPT_DIR/$model/model_1200000.pt \
--vocab-file $CKPT_DIR/$model/vocab.txt \
--vocoder-trt-engine-path $vocoder_trt_engine_path \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
"""
import argparse
import importlib
import json
import os
import sys
import time
import datasets
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
torch.manual_seed(0)
def get_args():
parser = argparse.ArgumentParser(description="extract speech code")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="huggingface dataset split name",
)
parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
parser.add_argument(
"--vocab-file",
required=True,
type=str,
help="vocab file",
)
parser.add_argument(
"--model-path",
required=True,
type=str,
help="model path, to load text embedding",
)
parser.add_argument(
"--tllm-model-dir",
required=True,
type=str,
help="tllm model dir",
)
parser.add_argument(
"--batch-size",
required=True,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
parser.add_argument(
"--vocoder",
default="vocos",
type=str,
help="vocoder name",
)
parser.add_argument(
"--vocoder-trt-engine-path",
default=None,
type=str,
help="vocoder trt engine path",
)
parser.add_argument("--enable-warmup", action="store_true")
parser.add_argument("--remove-input-padding", action="store_true")
parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
args = parser.parse_args()
return args
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
if use_perf:
torch.cuda.nvtx.range_push("data_collator")
target_sample_rate = 24000
target_rms = 0.1
(
ids,
ref_rms_list,
ref_mel_list,
ref_mel_len_list,
estimated_reference_target_mel_len,
reference_target_texts_list,
) = (
[],
[],
[],
[],
[],
[],
)
for i, item in enumerate(batch):
item_id, prompt_text, target_text = (
item["id"],
item["prompt_text"],
item["target_text"],
)
ids.append(item_id)
reference_target_texts_list.append(prompt_text + target_text)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
ref_rms_list.append(ref_rms)
if ref_rms < target_rms:
ref_audio_org = ref_audio_org * target_rms / ref_rms
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
if use_perf:
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
ref_audio = ref_audio.to("cuda")
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
if use_perf:
torch.cuda.nvtx.range_pop()
ref_mel_len = ref_mel.shape[-1]
assert ref_mel.shape[0] == 100
ref_mel_list.append(ref_mel)
ref_mel_len_list.append(ref_mel_len)
estimated_reference_target_mel_len.append(
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
)
ref_mel_batch = padded_mel_batch(ref_mel_list)
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
if use_perf:
torch.cuda.nvtx.range_pop()
return {
"ids": ids,
"ref_rms_list": ref_rms_list,
"ref_mel_batch": ref_mel_batch,
"ref_mel_len_batch": ref_mel_len_batch,
"text_pad_sequence": text_pad_sequence,
"estimated_reference_target_mel_len": estimated_reference_target_mel_len,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
# Initialize the distributed process group (NCCL backend)
dist.init_process_group(
"nccl",
)
return world_size, local_rank, rank
def load_vocoder(
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
if vocoder_name == "vocos":
if vocoder_trt_engine_path is not None:
vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
else:
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
if is_local:
print(f"Load vocos from local path {local_path}")
config_path = f"{local_path}/config.yaml"
model_path = f"{local_path}/pytorch_model.bin"
else:
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
repo_id = "charactr/vocos-mel-24khz"
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
vocoder = Vocos.from_hparams(config_path)
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
from vocos.feature_extractors import EncodecFeatures
if isinstance(vocoder.feature_extractor, EncodecFeatures):
encodec_parameters = {
"feature_extractor.encodec." + key: value
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
}
state_dict.update(encodec_parameters)
vocoder.load_state_dict(state_dict)
vocoder = vocoder.eval().to(device)
elif vocoder_name == "bigvgan":
raise NotImplementedError("BigVGAN is not implemented yet")
return vocoder
class VocosTensorRT:
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
logger.info(f"Loading vocoder engine from {engine_path}")
self.engine_path = engine_path
with open(engine_path, "rb") as f:
engine_buffer = f.read()
self.session = Session.from_serialized_engine(engine_buffer)
self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
def decode(self, mels):
mels = mels.contiguous()
inputs = {"mel": mels}
output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
outputs = {
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
}
ok = self.session.run(inputs, outputs, self.stream)
assert ok, "Runtime execution failed for vae session"
samples = outputs["waveform"]
return samples
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
tllm_model_dir = args.tllm_model_dir
with open(os.path.join(tllm_model_dir, "config.json")) as f:
tllm_model_config = json.load(f)
if args.backend_type == "trt":
model = F5TTS(
tllm_model_config,
debug_mode=False,
tllm_model_dir=tllm_model_dir,
model_path=args.model_path,
vocab_size=vocab_size,
)
elif args.backend_type == "pytorch":
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT
pretrained_config = tllm_model_config["pretrained_config"]
pt_model_config = dict(
dim=pretrained_config["hidden_size"],
depth=pretrained_config["num_hidden_layers"],
heads=pretrained_config["num_attention_heads"],
ff_mult=pretrained_config["ff_mult"],
text_dim=pretrained_config["text_dim"],
text_mask_padding=pretrained_config["text_mask_padding"],
conv_layers=pretrained_config["conv_layers"],
pe_attn_head=pretrained_config["pe_attn_head"],
)
model = load_model(DiT, pt_model_config, args.model_path)
vocoder = load_vocoder(
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
)
dataset = load_dataset(
"yuekai/seed_tts",
split=args.split_name,
trust_remote_code=True,
)
def add_estimated_duration(example):
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
estimated_duration = prompt_audio_len * scale_factor
example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
return example
dataset = dataset.map(add_estimated_duration)
dataset = dataset.sort("estimated_duration", reverse=True)
if args.use_perf:
# dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
# dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
# dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
dataset = datasets.concatenate_datasets(dataset_list_short)
if world_size > 1:
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
else:
# This would disable shuffling
sampler = None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
)
total_steps = len(dataset)
if args.enable_warmup:
for batch in dataloader:
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.backend_type == "trt":
_ = model.sample(
text_pad_seq,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
decoding_time = 0
vocoder_time = 0
total_duration = 0
if args.use_perf:
torch.cuda.cudart().cudaProfilerStart()
total_decoding_time = time.time()
for batch in dataloader:
if args.use_perf:
torch.cuda.nvtx.range_push("data sample")
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
text_pad_seq = batch["text_pad_sequence"].to(device)
total_mel_lens = batch["estimated_reference_target_mel_len"]
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
if args.use_perf:
torch.cuda.nvtx.range_pop()
if args.backend_type == "trt":
generated, cost_time = model.sample(
text_pad_seq,
cond_pad_seq,
ref_mel_lens,
total_mel_lens,
remove_input_padding=args.remove_input_padding,
use_perf=args.use_perf,
)
elif args.backend_type == "pytorch":
total_mel_lens = torch.tensor(total_mel_lens, device=device)
with torch.inference_mode():
start_time = time.time()
generated, _ = model.sample(
cond=ref_mels,
text=text_pad_seq,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=32,
cfg_strength=2.0,
sway_sampling_coef=-1,
)
cost_time = time.time() - start_time
decoding_time += cost_time
vocoder_start_time = time.time()
target_rms = 0.1
target_sample_rate = 24000
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
if args.vocoder == "vocos":
if args.use_perf:
torch.cuda.nvtx.range_push("vocoder decode")
generated_wave = vocoder.decode(gen_mel_spec).cpu()
if args.use_perf:
torch.cuda.nvtx.range_pop()
else:
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
if batch["ref_rms_list"][i] < target_rms:
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
utt = batch["ids"][i]
torchaudio.save(
f"{args.output_dir}/{utt}.wav",
generated_wave,
target_sample_rate,
)
total_duration += generated_wave.shape[1] / target_sample_rate
vocoder_time += time.time() - vocoder_start_time
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
total_decoding_time = time.time() - total_decoding_time
if rank == 0:
progress_bar.close()
rtf = total_decoding_time / total_duration
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
s += f"batch size: {args.batch_size}\n"
print(s)
with open(f"{args.output_dir}/rtf.txt", "w") as f:
f.write(s)
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@@ -30,21 +30,11 @@ python3 client_grpc.py \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import os
import time
import types
@@ -177,8 +167,7 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
help="triton model_repo module name to request",
)
parser.add_argument(
@@ -207,7 +196,7 @@ def get_args():
"--log-dir",
type=str,
required=False,
default="./tmp",
default="./tests/client_grpc",
help="log directory",
)
@@ -221,8 +210,8 @@ def get_args():
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -231,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=16000):
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate
@@ -245,7 +233,7 @@ async def send(
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
save_sample_rate: int = 24000,
):
total_duration = 0.0
latency_data = []
@@ -255,7 +243,7 @@ async def send(
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
@@ -311,8 +299,9 @@ async def send(
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, estimated_target_duration))
total_duration += estimated_target_duration
actual_duration = len(audio) / save_sample_rate
latency_data.append((end, actual_duration))
total_duration += actual_duration
return total_duration, latency_data
@@ -417,7 +406,7 @@ async def main():
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
save_sample_rate=24000,
)
)
tasks.append(task)

View File

@@ -23,10 +23,12 @@
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import os
import numpy as np
import requests
import soundfile as sf
import numpy as np
import argparse
def get_args():
@@ -64,33 +66,32 @@ def get_args():
"--model-name",
type=str,
default="f5_tts",
choices=["f5_tts", "spark_tts"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
default="tests/client_http.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
samples,
waveform,
reference_text,
target_text,
sample_rate=16000,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
lengths = np.array([[len(samples)]], dtype=np.int32)
samples = samples.reshape(1, -1).astype(np.float32)
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
waveform = waveform.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
{
"name": "reference_wav_len",
"shape": lengths.shape,
@@ -105,19 +106,18 @@ def prepare_request(
return data
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
samples, sample_rate = sf.read(wav_path)
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
samples = resample(samples, num_samples)
return samples, target_sample_rate
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
return waveform, target_sample_rate
if __name__ == "__main__":
@@ -127,11 +127,11 @@ if __name__ == "__main__":
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
waveform, sr = load_audio(args.reference_audio)
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
waveform = np.array(waveform, dtype=np.float32)
data = prepare_request(waveform, args.reference_text, args.target_text)
rsp = requests.post(
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
@@ -139,4 +139,5 @@ if __name__ == "__main__":
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
sf.write(args.output_audio, audio, 24000, "PCM_16")

View File

@@ -1,18 +1,18 @@
import tensorrt as trt
import os
import math
import os
import time
from typing import List, Optional
from functools import wraps
from typing import List, Optional
import tensorrt as trt
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session
from torch.nn.utils.rnn import pad_sequence
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
@@ -33,26 +33,35 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
def __init__(
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text):
# only keep tensors with value not -1
text_mask = text != -1
text_pad_cut_off_index = text_mask.sum(dim=1).max()
def forward(self, text, seq_len, drop_text=False):
text = text + 1
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
if self.mask_padding:
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
for block in self.text_blocks:
text = block(text)
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
else:
text = self.text_blocks(text)
text = text[:, :text_pad_cut_off_index]
text = self.text_embed(text)
text = text + self.freqs_cis[: text.shape[1], :]
for block in self.text_blocks:
text = block(text)
# padding text to the original length
# text shape: B,seq_len,C
# pad at the second dimension
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
return text
@@ -113,20 +122,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def load_checkpoint(ckpt_path, use_ema=True):
checkpoint = torch.load(ckpt_path, weights_only=True)
def get_text_embed_dict(ckpt_path, use_ema=True):
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
checkpoint = load_file(ckpt_path)
else:
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
if use_ema:
if ckpt_type == "safetensors":
checkpoint = {"ema_model_state_dict": checkpoint}
checkpoint["model_state_dict"] = {
k.replace("ema_model.", ""): v
for k, v in checkpoint["ema_model_state_dict"].items()
if k not in ["initted", "step"]
}
dict_state = checkpoint["model_state_dict"]
else:
if ckpt_type == "safetensors":
checkpoint = {"model_state_dict": checkpoint}
model_params = checkpoint["model_state_dict"]
text_embed_dict = {}
for key in dict_state.keys():
for key in model_params.keys():
# transformer.text_embed.text_embed.weight -> text_embed.weight
if "text_embed" in key:
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
return text_embed_dict
@@ -195,18 +217,16 @@ class F5TTS(object):
self.max_mel_len = 4096
self.text_embedding = TextEmbedding(
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
text_num_embeds=vocab_size,
text_dim=config["pretrained_config"]["text_dim"],
mask_padding=config["pretrained_config"]["text_mask_padding"],
conv_layers=config["pretrained_config"]["conv_layers"],
precompute_max_pos=self.max_mel_len,
).to(self.device)
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
# self.max_mel_len = 3000
self.head_dim = 64
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
self.head_dim = config["pretrained_config"]["dim_head"]
self.base_rescale_factor = 1.0
self.interpolation_factor = 1.0
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
@@ -215,14 +235,23 @@ class F5TTS(object):
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
self.rope_cos = self.freqs.cos().half()
self.rope_sin = self.freqs.sin().half()
self.nfe_steps = 16
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
self.nfe_steps = 32
epss = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
time_step = 1 - torch.cos(torch.pi * t / 2)
delta_t = torch.diff(time_step)
# WAR: hard coding 256 here
tmp_dim = 256
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
half_dim = tmp_dim // 2
freq_embed_dim = 256 # Warning: hard coding 256 here
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
half_dim = freq_embed_dim // 2
emb_factor = math.log(10000) / (half_dim - 1)
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
for i in range(self.nfe_steps):
@@ -345,7 +374,7 @@ class F5TTS(object):
def sample(
self,
text_pad_sequence: torch.Tensor,
ref_mel_batch: torch.Tensor,
cond_pad_sequence: torch.Tensor,
ref_mel_len_batch: torch.Tensor,
estimated_reference_target_mel_len: List[int],
remove_input_padding: bool = False,
@@ -354,26 +383,43 @@ class F5TTS(object):
if use_perf:
torch.cuda.nvtx.range_push("text embedding")
batch = text_pad_sequence.shape[0]
max_seq_len = ref_mel_batch.shape[1]
max_seq_len = cond_pad_sequence.shape[1]
text_pad_sequence_drop = torch.cat(
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
# get text_embed one by one to avoid misalignment
text_and_drop_embedding_list = []
for i in range(batch):
text_embedding_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=False,
)
text_embedding_drop_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=True,
)
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
# pad separately computed text_embed to form batch with max_seq_len
text_and_drop_embedding = pad_sequence(
text_and_drop_embedding_list,
batch_first=True,
padding_value=0,
)
text_embedding = text_and_drop_embedding[0::2]
text_embedding_drop = text_and_drop_embedding[1::2]
text_embedding_drop_list = []
for i in range(batch + 1):
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
text_embedding = text_embedding_drop_condition[:-1]
# text_embedding_drop B,T,C batch should be the same
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
noise = torch.randn_like(ref_mel_batch).to(self.device)
noise = torch.randn_like(cond_pad_sequence).to(self.device)
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
cat_mel_text = torch.cat(
(
cond_pad_sequence,
text_embedding,
),
dim=-1,
)
cat_mel_text_drop = torch.cat(
(
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),

View File

@@ -24,16 +24,16 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.dlpack import from_dlpack, to_dlpack
import torchaudio
import jieba
import triton_python_backend_utils as pb_utils
from pypinyin import Style, lazy_pinyin
import os
import jieba
import torch
import torchaudio
import triton_python_backend_utils as pb_utils
from f5_tts_trtllm import F5TTS
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
from torch.utils.dlpack import from_dlpack, to_dlpack
def get_tokenizer(vocab_file_path: str):
@@ -98,7 +98,8 @@ def list_str_to_idx(
padding_value=-1,
): # noqa: F722
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
return list_idx_tensors
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
class TritonPythonModel:
@@ -106,13 +107,12 @@ class TritonPythonModel:
self.use_perf = True
self.device = torch.device("cuda")
self.target_audio_sample_rate = 24000
self.target_rms = 0.15 # target rms for audio
self.target_rms = 0.1 # minimum rms at inference; normalize audio up to this value if lower
self.n_fft = 1024
self.win_length = 1024
self.hop_length = 256
self.n_mel_channels = 100
self.max_mel_len = 3000
self.head_dim = 64
self.max_mel_len = 4096
parameters = json.loads(args["model_config"])["parameters"]
for key, value in parameters.items():
@@ -180,7 +180,8 @@ class TritonPythonModel:
reference_target_texts_list,
estimated_reference_target_mel_len,
reference_mel_len,
) = [], [], [], [], []
reference_rms_list,
) = [], [], [], [], [], []
mel_features_list = []
if self.use_perf:
torch.cuda.nvtx.range_push("preprocess")
@@ -207,6 +208,7 @@ class TritonPythonModel:
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
if ref_rms < self.target_rms:
wav = wav * self.target_rms / ref_rms
reference_rms_list.append(ref_rms)
if self.reference_sample_rate != self.target_audio_sample_rate:
wav = self.resampler(wav)
wav = wav.to(self.device)
@@ -219,13 +221,15 @@ class TritonPythonModel:
reference_mel_len.append(mel_features.shape[1])
estimated_reference_target_mel_len.append(
int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
int(
mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
)
)
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
@@ -234,15 +238,6 @@ class TritonPythonModel:
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
for i, item in enumerate(text_pad_sequence):
text_pad_sequence[i] = F.pad(
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
)
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
text_pad_sequence = F.pad(
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
)
if self.use_perf:
torch.cuda.nvtx.range_pop()
@@ -259,13 +254,12 @@ class TritonPythonModel:
responses = []
for i in range(batch):
ref_me_len = reference_mel_len[i]
ref_mel_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < self.target_rms:
audio = audio * self.target_rms / rms
if reference_rms_list[i] < self.target_rms:
audio = audio * reference_rms_list[i] / self.target_rms
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])

View File

@@ -33,7 +33,7 @@ parameters [
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
value: {string_value:"24000"}
},
{
key: "vocoder",

View File

@@ -34,6 +34,7 @@ from .deepseek_v2.model import DeepseekV2ForCausalLM
from .dit.model import DiT
from .eagle.model import EagleForCausalLM
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
from .f5tts.model import F5TTS
from .falcon.config import FalconConfig
from .falcon.model import FalconForCausalLM, FalconModel
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
@@ -54,12 +55,12 @@ from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodi
from .mpt.model import MPTForCausalLM, MPTModel
from .nemotron_nas.model import DeciLMForCausalLM
from .opt.model import OPTForCausalLM, OPTModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .phi.model import PhiForCausalLM, PhiModel
from .phi3.model import Phi3ForCausalLM, Phi3Model
from .qwen.model import QWenForCausalLM
from .recurrentgemma.model import RecurrentGemmaForCausalLM
from .redrafter.model import ReDrafterForCausalLM
from .f5tts.model import F5TTS
__all__ = [
"BertModel",

View File

@@ -1,23 +1,20 @@
from __future__ import annotations
import sys
import os
import sys
from collections import OrderedDict
import tensorrt as trt
from collections import OrderedDict
from tensorrt_llm._common import default_net
from ..._utils import str_dtype_to_trt
from ...functional import Tensor, concat
from ...layers import Linear
from ...module import Module, ModuleList
from ...plugin import current_all_reduce_helper
from ..modeling_utils import PretrainedConfig, PretrainedModel
from ...functional import Tensor, concat
from ...module import Module, ModuleList
from tensorrt_llm._common import default_net
from ...layers import Linear
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
from .modules import (
TimestepEmbedding,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
)
current_file_path = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file_path)
@@ -53,6 +50,7 @@ class F5TTS(PretrainedModel):
dim_head=config.dim_head,
ff_mult=config.ff_mult,
dropout=config.dropout,
pe_attn_head=config.pe_attn_head,
)
for _ in range(self.depth)
]
@@ -82,13 +80,12 @@ class F5TTS(PretrainedModel):
def prepare_inputs(self, **kwargs):
max_batch_size = kwargs["max_batch_size"]
batch_size_range = [2, 2, max_batch_size]
mel_size = 100
max_seq_len = 3000
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
hidden_size = 512
concat_feature_dim = mel_size + hidden_size
freq_embed_dim = 256
head_dim = 64
mel_size = self.config.mel_dim
max_seq_len = 3000 # 4096
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
concat_feature_dim = mel_size + self.config.text_dim
freq_embed_dim = 256 # Warning: hard coding 256 here
head_dim = self.config.dim_head
mapping = self.config.mapping
if mapping.tp_size > 1:
current_all_reduce_helper().set_workspace_tensor(mapping, 1)

View File

@@ -3,33 +3,35 @@ from __future__ import annotations
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import numpy as np
from tensorrt_llm._common import default_net
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
from ...functional import (
Tensor,
bert_attention,
cast,
chunk,
concat,
constant,
expand,
expand_dims,
expand_dims_like,
expand_mask,
gelu,
matmul,
permute,
shape,
silu,
slice,
permute,
expand_mask,
expand_dims_like,
unsqueeze,
matmul,
softmax,
squeeze,
cast,
gelu,
unsqueeze,
view,
)
from ...functional import expand_dims, view, bert_attention
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
from ...module import Module
@@ -225,29 +227,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
return out
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
if default_net().plugin_config.remove_input_padding:
rot_dim = shape(rope_cos, -1) # 64
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
end_dim = shape(x, -1) - shape(rope_cos, -1)
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
else:
rot_dim = shape(rope_cos, 2) # 64
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
end_dim = shape(x, 2) - shape(rope_cos, 2)
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
full_dim = x.size(-1)
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
if pe_attn_head is None:
pe_attn_head = full_dim // head_dim
rotated_dim = head_dim * pe_attn_head
rotated_and_unrotated_list = []
if default_net().plugin_config.remove_input_padding: # for [N, D] input
new_t_shape = concat([shape(x, 0), head_dim]) # (-1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (-1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
rotated_and_unrotated_list.append(x_unrotated)
else: # for [B, N, D] input
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
for i in range(pe_attn_head):
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
rotated_and_unrotated_list.append(x_rotated_i)
new_t_unrotated_shape = concat(
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
) # (2, -1, 1024 - 64 * pe_attn_head)
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
rotated_and_unrotated_list.append(x_unrotated)
out = concat(rotated_and_unrotated_list, dim=-1)
return out
class AttnProcessor:
def __init__(self):
pass
def __init__(
self,
pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all
):
self.pe_attn_head = pe_attn_head
def __call__(
self,
@@ -263,8 +288,8 @@ class AttnProcessor:
key = attn.to_k(x)
value = attn.to_v(x)
# k,v,q all (2,1226,1024)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
# attention
inner_dim = key.shape[-1]
@@ -352,12 +377,12 @@ class AttnProcessor:
# DiT Block
class DiTBlock(Module):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
processor=AttnProcessor(pe_attn_head=pe_attn_head),
dim=dim,
heads=heads,
dim_head=dim_head,

View File

@@ -1,64 +1,66 @@
stage=$1
stop_stage=$2
model=$3 # F5TTS_Base
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
if [ -z "$model" ]; then
echo "Model is none"
exit 1
model=F5TTS_v1_Base
fi
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
export CUDA_VISIBLE_DEVICES=0
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
CKPT_DIR=../../../../ckpts
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
vocoder_trt_engine_path=vocos_vocoder.plan
model_repo=./model_repo
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
MODEL_REPO=./model_repo
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Downloading f5 tts from huggingface"
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
echo "Downloading F5-TTS from huggingface"
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
fi
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # by default, pick the checkpoint with the highest update step
vocab_file=$CKPT_DIR/$model/vocab.txt
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
echo "Converting checkpoint"
python3 ./scripts/convert_checkpoint.py \
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
python3 scripts/convert_checkpoint.py \
--pytorch_ckpt $ckpt_file \
--output_dir $TRTLLM_CKPT_DIR --model_name $model
python_package_path=/usr/local/lib/python3.12/dist-packages
cp -r patch/* $python_package_path/tensorrt_llm/models
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
--max_batch_size 8 \
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Exporting vocos vocoder"
onnx_vocoder_path=vocos_vocoder.onnx
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
echo "Building triton server"
rm -r $model_repo
cp -r ./model_repo_f5_tts $model_repo
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
rm -r $MODEL_REPO
cp -r ./model_repo_f5_tts $MODEL_REPO
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Starting triton server"
tritonserver --model-repository=$model_repo
tritonserver --model-repository=$MODEL_REPO
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
echo "Testing triton server"
num_task=1
log_dir=./log_concurrent_tasks_${num_task}
split_name=wenetspeech4tts
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
rm -r $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
@@ -66,5 +68,45 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
audio=../../infer/examples/basic/basic_ref_en.wav
reference_text="Some call me nature, others call me mother nature."
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "TRT-LLM: offline decoding benchmark test"
batch_size=1
split_name=wenetspeech4tts
backend_type=trt
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--enable-warmup \
--split-name $split_name \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
echo "Native Pytorch: offline decoding benchmark test"
if ! python3 -c "import f5_tts" &> /dev/null; then
pip install -e ../../../../
fi
batch_size=1 # set attn_mask_enabled=True if batching in actual use case
split_name=wenetspeech4tts
backend_type=pytorch
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
rm -r $log_dir
torchrun --nproc_per_node=1 \
benchmark.py --output-dir $log_dir \
--batch-size $batch_size \
--split-name $split_name \
--enable-warmup \
--model-path $ckpt_file \
--vocab-file $vocab_file \
--backend-type $backend_type \
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
fi
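For orientation, the staged script above is driven by three positional arguments: the start stage, the stop stage, and the model name. A hedged usage sketch follows, assuming the script is saved as `run.sh` (the filename is not shown in this excerpt):

```bash
# Usage sketch only; "run.sh" is an assumed filename for the staged script above.
bash run.sh 0 3 F5TTS_v1_Base   # stages 0-3: download checkpoint, build TRT-LLM engine, export vocoder, assemble Triton model repo
bash run.sh 4 4 F5TTS_v1_Base   # stage 4: launch the Triton server
bash run.sh 5 6 F5TTS_v1_Base   # stages 5-6: gRPC and HTTP client tests against the running server
```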

View File

@@ -40,6 +40,7 @@ import torch as th
import torch.nn.functional as F
from scipy.signal import check_COLA, get_window
support_clp_op = None
if th.__version__ >= "1.7.0":
from torch.fft import rfft as fft

View File

@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import safetensors.torch
import torch
from tensorrt_llm import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
@@ -24,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
return split_v.contiguous()
FACEBOOK_DIT_NAME_MAPPING = {
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
}
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Base",
choices=[
"F5TTS_Base",
],
) # TODO: support F5TTS_v1_Base
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
parser.add_argument(
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--cfg_scale", type=float, default=4.0)
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
@@ -194,33 +37,119 @@ def parse_arguments():
parser.add_argument(
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
)
parser.add_argument(
"--model_name",
type=str,
default="F5TTS_Custom",
choices=[
"F5TTS_v1_Base",
"F5TTS_Base",
"F5TTS_v1_Small",
"F5TTS_Small",
], # if set, overrides the hyperparameters below
)
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
parser.add_argument(
"--text_mask_padding",
type=lambda x: x.lower() == "true",
choices=[True, False],
default=True,
help="Whether apply padding mask for conv layers in text encoder",
)
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
args = parser.parse_args()
# override the hyperparameter defaults above if a known --model_name is specified
if args.model_name == "F5TTS_v1_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Base":
args.hidden_size = 1024
args.depth = 22
args.num_heads = 16
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
elif args.model_name == "F5TTS_v1_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = True
args.conv_layers = 4
args.pe_attn_head = None
elif args.model_name == "F5TTS_Small":
args.hidden_size = 768
args.depth = 18
args.num_heads = 12
args.dim_head = 64
args.ff_mult = 2
args.text_dim = 512
args.text_mask_padding = False
args.conv_layers = 4
args.pe_attn_head = 1
return args
def convert_timm_dit(args, mapping, dtype="float32"):
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
weights = {}
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
tensor_parallel = mapping.tp_size
model_params = dict(torch.load(args.timm_ckpt))
model_params = {
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
ckpt_path = args.pytorch_ckpt
ckpt_type = ckpt_path.split(".")[-1]
if ckpt_type == "safetensors":
from safetensors.torch import load_file
model_params = load_file(ckpt_path)
else:
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
prefix = "ema_model.transformer." if use_ema else "transformer."
if any(k.startswith(prefix) for k in model_params.keys()):
model_params = {
key[len(prefix) :] if key.startswith(prefix) else key: value
for key, value in model_params.items()
if key.startswith(prefix)
}
pytorch_to_trtllm_name = {
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
}
prefix = "ema_model.transformer."
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
def get_trtllm_name(timm_name):
for k, v in timm_to_trtllm_name.items():
m = re.match(k, timm_name)
if m is not None:
if "*" in v:
v = v.replace("*", m.groups()[0])
return v
return timm_name
def get_trtllm_name(pytorch_name):
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
if trtllm_name_if_matched != pytorch_name:
return trtllm_name_if_matched
return pytorch_name
weights = dict()
for name, param in model_params.items():
@@ -231,7 +160,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
assert len(weights) == len(model_params)
# new_prefix = 'f5_transformer.'
# new_prefix = "f5_transformer."
new_prefix = ""
weights = {new_prefix + key: value for key, value in weights.items()}
import math
@@ -273,19 +202,19 @@ def save_config(args):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
config = {
"architecture": "F5TTS",
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
"dtype": args.dtype,
"hidden_size": 1024,
"num_hidden_layers": 22,
"num_attention_heads": 16,
"dim_head": 64,
"dropout": 0.1,
"ff_mult": 2,
"hidden_size": args.hidden_size,
"num_hidden_layers": args.depth,
"num_attention_heads": args.num_heads,
"dim_head": args.dim_head,
"dropout": 0.0, # inference-only
"ff_mult": args.ff_mult,
"mel_dim": 100,
"text_num_embeds": 256,
"text_dim": 512,
"conv_layers": 4,
"long_skip_connection": False,
"text_dim": args.text_dim,
"text_mask_padding": args.text_mask_padding,
"conv_layers": args.conv_layers,
"pe_attn_head": args.pe_attn_head,
"mapping": {
"world_size": args.cp_size * args.tp_size * args.pp_size,
"cp_size": args.cp_size,
@@ -297,7 +226,7 @@ def save_config(args):
config["quantization"] = {
"quant_algo": "FP8",
# TODO: add support for exclude modules.
# 'exclude_modules': "*final_layer*",
# "exclude_modules": "*final_layer*",
}
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
@@ -316,7 +245,7 @@ def covert_and_save(args, rank):
pp_size=args.pp_size,
)
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
@@ -345,9 +274,9 @@ def main():
assert args.pp_size == 1, "PP is not supported yet."
tik = time.time()
if args.timm_ckpt is None:
if args.pytorch_ckpt is None:
return
print("start execute")
print("Start execute")
execute(args.workers, [covert_and_save] * world_size, args)
tok = time.time()

View File

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from conv_stft import STFT
from huggingface_hub import hf_hub_download
from vocos import Vocos
import argparse
opset_version = 17

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Manual installation of TensorRT, in case not using NVIDIA NGC:
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
ONNX_PATH=$1
@@ -28,7 +30,7 @@ MAX_BATCH_SIZE=8
MIN_INPUT_LENGTH=1
OPT_INPUT_LENGTH=1000
MAX_INPUT_LENGTH=3000
MAX_INPUT_LENGTH=3000 # 4096
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
@@ -40,4 +42,3 @@ ${TRTEXEC} \
--maxShapes="mel:${MEL_MAX_SHAPE}" \
--onnx=${ONNX_PATH} \
--saveEngine=${ENGINE_PATH}

View File

@@ -1,12 +1,13 @@
import sys
import os
import sys
sys.path.append(os.getcwd())
from f5_tts.model import CFM, DiT
import torch
import thop
import torch
from f5_tts.model import CFM, DiT
""" ~155M """

View File

@@ -1,10 +1,12 @@
import socket
import asyncio
import pyaudio
import numpy as np
import logging
import socket
import time
import numpy as np
import pyaudio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,6 @@
import argparse
import gc
import logging
import numpy as np
import queue
import socket
import struct
@@ -10,6 +9,7 @@ import traceback
import wave
from importlib.resources import files
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download
@@ -18,12 +18,13 @@ from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
chunk_text,
preprocess_ref_audio_text,
load_vocoder,
load_model,
infer_batch_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

View File

@@ -1,5 +1,11 @@
# Training
Check your FFmpeg installation:
```bash
ffmpeg -version
```
If it is not found, install it first (or skip this step if you know other audio backends are available).
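For reference — exact package names and managers vary by platform, so treat this as a sketch rather than an official instruction — FFmpeg can typically be installed with:

```bash
# Debian/Ubuntu
sudo apt-get update && sudo apt-get install -y ffmpeg
# or inside a conda environment
conda install -c conda-forge ffmpeg
```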
## Prepare Dataset
Example data processing scripts are provided below; you may also tailor your own, along with a Dataset class in `src/f5_tts/model/dataset.py`.

View File

@@ -1,12 +1,13 @@
import os
import sys
import signal
import subprocess # For invoking ffprobe
import shutil
import concurrent.futures
import multiprocessing
import os
import shutil
import signal
import subprocess # For invoking ffprobe
import sys
from contextlib import contextmanager
sys.path.append(os.getcwd())
import argparse
@@ -16,12 +17,10 @@ from importlib.resources import files
from pathlib import Path
import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
@@ -209,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
out_dir.mkdir(exist_ok=True, parents=True)
print(f"\nSaving to {out_dir} ...")
# Save dataset with improved batch size for better I/O performance
raw_arrow_path = out_dir / "raw.arrow"
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# Save durations to JSON
dur_json_path = out_dir / "duration.json"

View File

@@ -7,20 +7,18 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import (
repetition_found,
convert_char_to_pinyin,
)
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
out_zh = {
@@ -183,6 +181,7 @@ def main():
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:

View File

@@ -0,0 +1,95 @@
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
# prepares Emilia dataset with the new format w/ Emilia-YODAS
import json
import os
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
from f5_tts.model.utils import repetition_found
# Define filters for exclusion
out_en = set()
en_filters = ["ا", "", ""]
def process_audio_directory(audio_dir):
sub_result, durations, vocab_set = [], [], set()
bad_case_en = 0
for file in audio_dir.iterdir():
if file.suffix == ".json":
with open(file, "r") as f:
obj = json.load(f)
text = obj["text"]
if any(f in text for f in en_filters) or repetition_found(text, length=4):
bad_case_en += 1
continue
duration = obj["duration"]
audio_file = file.with_suffix(".mp3")
if audio_file.exists():
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
durations.append(duration)
vocab_set.update(list(text))
return sub_result, durations, vocab_set, bad_case_en
def main():
assert tokenizer in ["pinyin", "char"]
result, duration_list, text_vocab_set = [], [], set()
total_bad_case_en = 0
executor = ProcessPoolExecutor(max_workers=max_workers)
futures = []
dataset_path = Path(dataset_dir)
for sub_dir in dataset_path.iterdir():
if sub_dir.is_dir():
futures.append(executor.submit(process_audio_directory, sub_dir))
for future in tqdm(futures, total=len(futures)):
sub_result, durations, vocab_set, bad_case_en = future.result()
result.extend(sub_result)
duration_list.extend(durations)
text_vocab_set.update(vocab_set)
total_bad_case_en += bad_case_en
executor.shutdown()
if not os.path.exists(f"{save_dir}"):
os.makedirs(f"{save_dir}")
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
with open(f"{save_dir}/vocab.txt", "w") as f:
for vocab in sorted(text_vocab_set):
f.write(vocab + "\n")
print(f"For {dataset_name}, sample count: {len(result)}")
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
print(f"Bad en transcription case: {total_bad_case_en}\n")
if __name__ == "__main__":
max_workers = 32
tokenizer = "char"
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
dataset_name = f"Emilia_EN_{tokenizer}"
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
main()
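Since all configuration (max_workers, tokenizer, dataset_dir, save_dir) is hard-coded in the `__main__` block above, a minimal invocation sketch — assuming the file is placed at the path noted in its header comment and that `dataset_dir` has been edited to point at your local Emilia-YODAS copy — is simply:

```bash
# Sketch only: edit dataset_dir (and max_workers) inside the script before running.
python src/f5_tts/train/datasets/prepare_emilia_v2.py
```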

View File

@@ -1,15 +1,17 @@
import os
import sys
sys.path.append(os.getcwd())
import json
from concurrent.futures import ProcessPoolExecutor
from importlib.resources import files
from pathlib import Path
from tqdm import tqdm
import soundfile as sf
from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm
def deal_with_audio_dir(audio_dir):
@@ -60,6 +62,7 @@ def main():
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:

View File

@@ -1,14 +1,16 @@
 import os
 import sys
 sys.path.append(os.getcwd())
 import json
 from importlib.resources import files
 from pathlib import Path
-from tqdm import tqdm
 import soundfile as sf
 from datasets.arrow_writer import ArrowWriter
+from tqdm import tqdm
 def main():
@@ -37,6 +39,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:

View File

@@ -4,15 +4,16 @@
 import os
 import sys
 sys.path.append(os.getcwd())
 import json
 from concurrent.futures import ProcessPoolExecutor
 from importlib.resources import files
-from tqdm import tqdm
 import torchaudio
 from datasets import Dataset
+from tqdm import tqdm
 from f5_tts.model.utils import convert_char_to_pinyin

View File

@@ -5,9 +5,9 @@ from importlib.resources import files
 from cached_path import cached_path
-from f5_tts.model import CFM, UNetT, DiT, Trainer
-from f5_tts.model.utils import get_tokenizer
+from f5_tts.model import CFM, DiT, Trainer, UNetT
 from f5_tts.model.dataset import load_dataset
+from f5_tts.model.utils import get_tokenizer
 # -------------------------- Dataset Settings --------------------------- #

View File

@@ -1,14 +1,12 @@
 import gc
 import json
-import numpy as np
 import os
 import platform
-import psutil
 import queue
 import random
 import re
-import signal
 import shutil
+import signal
 import subprocess
 import sys
 import tempfile
@@ -16,21 +14,23 @@ import threading
 import time
 from glob import glob
 from importlib.resources import files
-from scipy.io import wavfile
 import click
 import gradio as gr
 import librosa
+import numpy as np
+import psutil
 import torch
 import torchaudio
 from cached_path import cached_path
 from datasets import Dataset as Dataset_
 from datasets.arrow_writer import ArrowWriter
 from safetensors.torch import load_file, save_file
+from scipy.io import wavfile
 from f5_tts.api import F5TTS
-from f5_tts.model.utils import convert_char_to_pinyin
+from f5_tts.infer.utils_infer import transcribe
+from f5_tts.model.utils import convert_char_to_pinyin
 training_process = None
@@ -138,6 +138,8 @@ def load_settings(project_name):
         "logger": "none",
         "bnb_optimizer": False,
     }
+    if device == "mps":
+        default_settings["mixed_precision"] = "none"
     # Load settings from file if it exists
     if os.path.isfile(file_setting):
@@ -176,50 +178,12 @@ def get_audio_duration(audio_path):
     return audio.shape[1] / sample_rate
-def clear_text(text):
-    """Clean and prepare text by lowering the case and stripping whitespace."""
-    return text.lower().strip()
-def get_rms(
-    y,
-    frame_length=2048,
-    hop_length=512,
-    pad_mode="constant",
-):  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
-    padding = (int(frame_length // 2), int(frame_length // 2))
-    y = np.pad(y, padding, mode=pad_mode)
-    axis = -1
-    # put our new within-frame axis at the end for now
-    out_strides = y.strides + tuple([y.strides[axis]])
-    # Reduce the shape on the framing axis
-    x_shape_trimmed = list(y.shape)
-    x_shape_trimmed[axis] -= frame_length - 1
-    out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
-    xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
-    if axis < 0:
-        target_axis = axis - 1
-    else:
-        target_axis = axis + 1
-    xw = np.moveaxis(xw, -1, target_axis)
-    # Downsample along the target axis
-    slices = [slice(None)] * xw.ndim
-    slices[axis] = slice(0, None, hop_length)
-    x = xw[tuple(slices)]
-    # Calculate power
-    power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
-    return np.sqrt(power)
 class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
     def __init__(
         self,
         sr: int,
         threshold: float = -40.0,
-        min_length: int = 2000,
+        min_length: int = 20000,  # 20 seconds
         min_interval: int = 300,
         hop_size: int = 20,
         max_sil_kept: int = 2000,
@@ -250,7 +214,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
         samples = waveform
         if samples.shape[0] <= self.min_length:
             return [waveform]
-        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
+        rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
         sil_tags = []
         silence_start = None
         clip_start = 0
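The hunk above replaces the hand-rolled get_rms helper (removed earlier in this diff) with librosa.feature.rms, which computes the same frame-wise root-mean-square energy. A minimal sketch of the replacement call on a synthetic signal; the frame/hop values are illustrative stand-ins for the Slicer's win_size and hop_size:

import librosa
import numpy as np

sr = 32000
samples = np.random.default_rng(0).uniform(-1.0, 1.0, size=sr).astype(np.float32)  # 1 s of noise

rms_list = librosa.feature.rms(y=samples, frame_length=2048, hop_length=512).squeeze(0)
print(rms_list.shape)  # one RMS value per frame, as a 1-D array after squeeze(0)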
@@ -304,8 +268,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
             silence_end = min(total_frames, silence_start + self.max_sil_kept)
             pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
             sil_tags.append((pos, total_frames + 1))
-        # Apply and return slices.
-        #### audio + start time + end time
+        # Apply and return slices: [chunk, start, end]
         if len(sil_tags) == 0:
             return [[waveform, 0, int(total_frames * self.hop_size)]]
         else:
@@ -432,7 +395,7 @@ def start_training(
         fp16 = ""
     cmd = (
-        f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
+        f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
         f" --learning_rate {learning_rate}"
         f" --batch_size_per_gpu {batch_size_per_gpu}"
         f" --batch_size_type {batch_size_type}"
@@ -451,7 +414,7 @@ def start_training(
         cmd += " --finetune"
     if file_checkpoint_train != "":
-        cmd += f" --pretrain {file_checkpoint_train}"
+        cmd += f' --pretrain "{file_checkpoint_train}"'
     if tokenizer_file != "":
         cmd += f" --tokenizer_path {tokenizer_file}"
@@ -705,7 +668,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
         try:
             text = transcribe(file_segment, language)
-            text = text.lower().strip().replace('"', "")
+            text = text.strip()
             data += f"{name_segment}|{text}\n"
@@ -814,7 +777,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
             error_files.append([file_audio, "very short text length 3"])
             continue
-        text = clear_text(text)
+        text = text.strip()
         text = convert_char_to_pinyin([text], polyphone=True)[0]
         audio_path_list.append(file_audio)
@@ -833,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
     min_second = round(min(duration_list), 2)
     max_second = round(max(duration_list), 2)
-    with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
+    with ArrowWriter(path=file_raw) as writer:
         for line in progress.tqdm(result, total=len(result), desc="prepare data"):
             writer.write(line)
+        writer.finalize()
     with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
@@ -1097,7 +1061,7 @@ def vocab_extend(project_name, symbols, model_type):
     return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
-def vocab_check(project_name):
+def vocab_check(project_name, tokenizer_type):
     name_project = project_name
     path_project = os.path.join(path_data, name_project)
@@ -1125,7 +1089,9 @@ def vocab_check(project_name):
         if len(sp) != 2:
             continue
-        text = sp[1].lower().strip()
+        text = sp[1].strip()
+        if tokenizer_type == "pinyin":
+            text = convert_char_to_pinyin([text], polyphone=True)[0]
         for t in text:
             if t not in vocab and t not in miss_symbols_keep:
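With the extra tokenizer_type argument, vocab_check converts each transcript to pinyin tokens before the per-symbol vocab lookup, so the check sees the same units the training pipeline does. A minimal sketch of that check as a standalone helper; the function name and inputs are hypothetical, and only the conversion call mirrors the code above:

from f5_tts.model.utils import convert_char_to_pinyin

def find_missing_symbols(text, vocab, tokenizer_type="char"):
    """Return symbols in `text` that are absent from `vocab` (hypothetical helper)."""
    text = text.strip()
    if tokenizer_type == "pinyin":
        text = convert_char_to_pinyin([text], polyphone=True)[0]  # list of pinyin/char tokens
    return sorted({t for t in text if t not in vocab})

# e.g. load vocab.txt into a set and report anything the vocab cannot cover
# vocab = set(open("data/Emilia_EN_char/vocab.txt", encoding="utf-8").read().splitlines())
# print(find_missing_symbols("hello world", vocab, tokenizer_type="char"))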
@@ -1230,8 +1196,8 @@ def infer(
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         tts_api.infer(
             ref_file=ref_audio,
-            ref_text=ref_text.lower().strip(),
-            gen_text=gen_text.lower().strip(),
+            ref_text=ref_text.strip(),
+            gen_text=gen_text.strip(),
             nfe_step=nfe_step,
             speed=speed,
             remove_silence=remove_silence,
@@ -1496,7 +1462,9 @@ Using the extended model, you can finetune to a new language that is missing sym
                 txt_info_extend = gr.Textbox(label="Info", value="")
             txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
-            check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
+            check_button.click(
+                fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
+            )
             extend_button.click(
                 fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
             )

View File

@@ -10,6 +10,7 @@ from f5_tts.model import CFM, Trainer
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
+os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)