mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-05 20:40:12 -08:00
Compare commits
81 Commits
1.1.0
...
f2a4f8581f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2a4f8581f | ||
|
|
a17c5ae435 | ||
|
|
a0b8fb5df2 | ||
|
|
c8bfc3aa3d | ||
|
|
8d3ec72159 | ||
|
|
65ada48a62 | ||
|
|
77d3ec623b | ||
|
|
186799d6dc | ||
|
|
31bb78f2ab | ||
|
|
e61824009a | ||
|
|
06a74910bd | ||
|
|
ac3c43595c | ||
|
|
605fa13b42 | ||
|
|
5f35f27230 | ||
|
|
c96c3aeed8 | ||
|
|
9b60fe6a34 | ||
|
|
a275798a2f | ||
|
|
efc7a7498b | ||
|
|
9842314127 | ||
|
|
69b0e0110e | ||
|
|
52c84776e5 | ||
|
|
ebbd7bd91f | ||
|
|
ac42286d04 | ||
|
|
d937efa6f3 | ||
|
|
8975fca803 | ||
|
|
8b0053ad0c | ||
|
|
b3ef4ed1d7 | ||
|
|
b1a9438496 | ||
|
|
0914170e98 | ||
|
|
c6ebad0220 | ||
|
|
cfaba6387f | ||
|
|
646f34b20f | ||
|
|
2e2acc6ea2 | ||
|
|
6fbe7592f5 | ||
|
|
7e37bc5d9a | ||
|
|
35f130ee85 | ||
|
|
e6469f705f | ||
|
|
31cd818095 | ||
|
|
1d13664b24 | ||
|
|
b27471ea06 | ||
|
|
8fb55f107e | ||
|
|
ccb380b752 | ||
|
|
3027b43953 | ||
|
|
ecd1c3949a | ||
|
|
2968aa184f | ||
|
|
fb26b6d93e | ||
|
|
f7f266cdd9 | ||
|
|
695c735737 | ||
|
|
3e2a07da1d | ||
|
|
c47687487c | ||
|
|
ac79d0ec1e | ||
|
|
dad398c0c1 | ||
|
|
3d969bf78d | ||
|
|
7c741c05f9 | ||
|
|
6d1a1e886a | ||
|
|
b4efcd836a | ||
|
|
818b868fab | ||
|
|
e6fee5e9ba | ||
|
|
2de214c122 | ||
|
|
2999f642ce | ||
|
|
03cff73343 | ||
|
|
63c513840d | ||
|
|
3e6b6c0c0c | ||
|
|
f00ac4d06b | ||
|
|
b0658bfd24 | ||
|
|
0cae51d646 | ||
|
|
95976041f2 | ||
|
|
ba1bf74215 | ||
|
|
536c29ac57 | ||
|
|
c4c61b0110 | ||
|
|
5f80fec160 | ||
|
|
178cb8afe6 | ||
|
|
761c7ed938 | ||
|
|
13fd6f8e07 | ||
|
|
b2284b6cff | ||
|
|
4b4359bc39 | ||
|
|
fe5c562212 | ||
|
|
2374f8ec39 | ||
|
|
f4f10bff6c | ||
|
|
9771ec6a3a | ||
|
|
4b3cd13382 |
18
.github/workflows/sync-hf.yaml
vendored
18
.github/workflows/sync-hf.yaml
vendored
@@ -1,18 +0,0 @@
|
||||
name: Sync to HF Space
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
trigger_curl:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Send cURL POST request
|
||||
run: |
|
||||
curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
|
||||
-s \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
|
||||
@@ -3,11 +3,14 @@ repos:
|
||||
# Ruff version.
|
||||
rev: v0.11.2
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
name: ruff linter
|
||||
args: [--fix]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
name: ruff formatter
|
||||
- id: ruff
|
||||
name: ruff sorter
|
||||
args: [--select, I, --fix]
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
|
||||
45
README.md
45
README.md
@@ -2,11 +2,12 @@
|
||||
|
||||
[](https://github.com/SWivid/F5-TTS)
|
||||
[](https://arxiv.org/abs/2410.06885)
|
||||
[](https://swivid.github.io/F5-TTS/)
|
||||
[](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
[](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
|
||||
[](https://x-lance.sjtu.edu.cn/)
|
||||
[](https://www.pcl.ac.cn)
|
||||
[](https://swivid.github.io/F5-TTS/)
|
||||
[](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
[](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS)
|
||||
[](https://x-lance.sjtu.edu.cn/)
|
||||
[](https://www.sii.edu.cn/)
|
||||
[](https://www.pcl.ac.cn)
|
||||
<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
|
||||
|
||||
**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
|
||||
@@ -26,8 +27,8 @@
|
||||
### Create a separate environment if needed
|
||||
|
||||
```bash
|
||||
# Create a python 3.10 conda env (you could also use virtualenv)
|
||||
conda create -n f5-tts python=3.10
|
||||
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
|
||||
conda create -n f5-tts python=3.11
|
||||
conda activate f5-tts
|
||||
```
|
||||
|
||||
@@ -91,7 +92,7 @@ conda activate f5-tts
|
||||
> ```bash
|
||||
> git clone https://github.com/SWivid/F5-TTS.git
|
||||
> cd F5-TTS
|
||||
> # git submodule update --init --recursive # (optional, if need > bigvgan)
|
||||
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
|
||||
> pip install -e .
|
||||
> ```
|
||||
|
||||
@@ -107,6 +108,21 @@ docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,targ
|
||||
docker container run --rm -it --gpus=all --mount 'type=volume,source=f5-tts,target=/root/.cache/huggingface/hub/' -p 7860:7860 ghcr.io/swivid/f5-tts:main f5-tts_infer-gradio --host 0.0.0.0
|
||||
```
|
||||
|
||||
### Runtime
|
||||
|
||||
Deployment solution with Triton and TensorRT-LLM.
|
||||
|
||||
#### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF | Mode |
|
||||
|---------------------|----------------|-------------|--------|-----------------|
|
||||
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
|
||||
|
||||
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
|
||||
## Inference
|
||||
|
||||
@@ -179,19 +195,6 @@ f5-tts_infer-cli -c custom.toml
|
||||
f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml
|
||||
```
|
||||
|
||||
### 3. Runtime
|
||||
|
||||
Deployment solution with Triton and TensorRT-LLM.
|
||||
|
||||
#### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs.
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF |
|
||||
|-------|-------------|----------------|-------|
|
||||
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
|
||||
|
||||
See [detailed instructions](src/f5_tts/runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
|
||||
## Training
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "1.1.0"
|
||||
version = "1.1.9"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
@@ -15,17 +15,17 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"accelerate>=0.33.0",
|
||||
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
|
||||
"bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'",
|
||||
"cached_path",
|
||||
"click",
|
||||
"datasets",
|
||||
"ema_pytorch>=0.5.2",
|
||||
"gradio>=3.45.2",
|
||||
"gradio>=5.0.0",
|
||||
"hydra-core>=1.3.0",
|
||||
"jieba",
|
||||
"librosa",
|
||||
"matplotlib",
|
||||
"numpy<=1.26.4",
|
||||
"numpy<=1.26.4; python_version<='3.10'",
|
||||
"pydantic<=2.10.6",
|
||||
"pydub",
|
||||
"pypinyin",
|
||||
@@ -38,6 +38,7 @@ dependencies = [
|
||||
"tqdm>=4.65.0",
|
||||
"transformers",
|
||||
"transformers_stream_generator",
|
||||
"unidecode",
|
||||
"vocos",
|
||||
"wandb",
|
||||
"x_transformers>=1.31.14",
|
||||
|
||||
@@ -6,5 +6,5 @@ target-version = "py310"
|
||||
dummy-variable-rgx = "^_.*$"
|
||||
|
||||
[lint.isort]
|
||||
force-single-line = true
|
||||
force-single-line = false
|
||||
lines-after-imports = 2
|
||||
|
||||
@@ -9,13 +9,13 @@ from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
transcribe,
|
||||
preprocess_ref_audio_text,
|
||||
infer_process,
|
||||
remove_silence_for_generated_wav,
|
||||
save_spectrogram,
|
||||
transcribe,
|
||||
)
|
||||
from f5_tts.model.utils import seed_everything
|
||||
|
||||
@@ -154,8 +154,8 @@ if __name__ == "__main__":
|
||||
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
ref_text="Some call me nature, others call me mother nature.",
|
||||
gen_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
|
||||
@@ -31,6 +31,8 @@ model:
|
||||
text_mask_padding: False
|
||||
conv_layers: 4
|
||||
pe_attn_head: 1
|
||||
attn_backend: torch # torch | flash_attn
|
||||
attn_mask_enabled: False
|
||||
checkpoint_activations: False # recompute activations and save memory for extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
|
||||
@@ -31,6 +31,8 @@ model:
|
||||
text_mask_padding: False
|
||||
conv_layers: 4
|
||||
pe_attn_head: 1
|
||||
attn_backend: torch # torch | flash_attn
|
||||
attn_mask_enabled: False
|
||||
checkpoint_activations: False # recompute activations and save memory for extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
|
||||
@@ -32,6 +32,8 @@ model:
|
||||
qk_norm: null # null | rms_norm
|
||||
conv_layers: 4
|
||||
pe_attn_head: null
|
||||
attn_backend: torch # torch | flash_attn
|
||||
attn_mask_enabled: False
|
||||
checkpoint_activations: False # recompute activations and save memory for extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import argparse
|
||||
@@ -23,6 +24,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
accelerator = Accelerator()
|
||||
device = f"cuda:{accelerator.process_index}"
|
||||
|
||||
@@ -146,10 +148,15 @@ def main():
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
if not os.path.exists(ckpt_path):
|
||||
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
|
||||
if os.path.exists(ckpt_prefix + ".pt"):
|
||||
ckpt_path = ckpt_prefix + ".pt"
|
||||
elif os.path.exists(ckpt_prefix + ".safetensors"):
|
||||
ckpt_path = ckpt_prefix + ".safetensors"
|
||||
else:
|
||||
print("Loading from self-organized training checkpoints rather than released pretrained.")
|
||||
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||
|
||||
|
||||
@@ -5,17 +5,16 @@ import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_librispeech_test,
|
||||
run_asr_wer,
|
||||
run_sim,
|
||||
)
|
||||
|
||||
from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
|
||||
|
||||
|
||||
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
@@ -5,17 +5,16 @@ import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import multiprocessing as mp
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
from f5_tts.eval.utils_eval import (
|
||||
get_seed_tts_test,
|
||||
run_asr_wer,
|
||||
run_sim,
|
||||
)
|
||||
|
||||
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
|
||||
|
||||
|
||||
rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
@@ -126,8 +126,13 @@ def get_inference_prompt(
|
||||
else:
|
||||
text_list = text
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# Duration, mel frame length
|
||||
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||
ref_mel_len = ref_mel.shape[-1]
|
||||
|
||||
if use_truth_duration:
|
||||
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||
if gt_sr != target_sample_rate:
|
||||
@@ -142,10 +147,6 @@ def get_inference_prompt(
|
||||
gen_text_len = len(gt_text.encode("utf-8"))
|
||||
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# deal with batch
|
||||
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||
assert min_tokens <= total_mel_len <= max_tokens, (
|
||||
|
||||
@@ -13,7 +13,7 @@ To avoid possible inference failures, make sure you have seen through the follow
|
||||
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
|
||||
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
|
||||
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
|
||||
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
|
||||
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
|
||||
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
|
||||
|
||||
|
||||
@@ -129,6 +129,28 @@ ref_text = ""
|
||||
```
|
||||
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
|
||||
|
||||
## API Usage
|
||||
|
||||
```python
|
||||
from importlib.resources import files
|
||||
from f5_tts.api import F5TTS
|
||||
|
||||
f5tts = F5TTS()
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
)
|
||||
```
|
||||
Check [api.py](../api.py) for more details.
|
||||
|
||||
## TensorRT-LLM Deployment
|
||||
|
||||
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
## Socket Real-time Service
|
||||
|
||||
Real-time voice output with chunk stream:
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
|
||||
- [French](#french)
|
||||
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
|
||||
- [German](#german)
|
||||
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
|
||||
- [Hindi](#hindi)
|
||||
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
|
||||
- [Italian](#italian)
|
||||
@@ -97,6 +99,22 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
|
||||
|
||||
|
||||
## German
|
||||
|
||||
#### F5-TTS Base @ de @ hvoss-techfak
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
|
||||
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
|
||||
|
||||
|
||||
## Hindi
|
||||
|
||||
#### F5-TTS Small @ hi @ SPRINGLab
|
||||
|
||||
@@ -13,8 +13,8 @@ output_file = "infer_cli_story.wav"
|
||||
[voices.town]
|
||||
ref_audio = "infer/examples/multi/town.flac"
|
||||
ref_text = ""
|
||||
speed = 0.8 # will ignore global speed
|
||||
|
||||
[voices.country]
|
||||
ref_audio = "infer/examples/multi/country.flac"
|
||||
ref_text = ""
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
|
||||
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."
|
||||
@@ -12,22 +12,23 @@ import tomli
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
from unidecode import unidecode
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
mel_spec_type,
|
||||
target_rms,
|
||||
cross_fade_duration,
|
||||
nfe_step,
|
||||
cfg_strength,
|
||||
sway_sampling_coef,
|
||||
speed,
|
||||
fix_duration,
|
||||
cross_fade_duration,
|
||||
device,
|
||||
fix_duration,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
mel_spec_type,
|
||||
nfe_step,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
speed,
|
||||
sway_sampling_coef,
|
||||
target_rms,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,6 +113,11 @@ parser.add_argument(
|
||||
action="store_true",
|
||||
help="To save each audio chunks during inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_legacy_text",
|
||||
action="store_false",
|
||||
help="Not to use lossy ASCII transliterations of unicode text in saved file names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_silence",
|
||||
action="store_true",
|
||||
@@ -197,6 +203,12 @@ output_file = args.output_file or config.get(
|
||||
)
|
||||
|
||||
save_chunk = args.save_chunk or config.get("save_chunk", False)
|
||||
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
|
||||
if save_chunk and use_legacy_text:
|
||||
print(
|
||||
"\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
|
||||
)
|
||||
|
||||
remove_silence = args.remove_silence or config.get("remove_silence", False)
|
||||
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
|
||||
|
||||
@@ -321,9 +333,10 @@ def main():
|
||||
text = re.sub(reg2, "", text)
|
||||
ref_audio_ = voices[voice]["ref_audio"]
|
||||
ref_text_ = voices[voice]["ref_text"]
|
||||
local_speed = voices[voice].get("speed", speed)
|
||||
gen_text_ = text.strip()
|
||||
print(f"Voice: {voice}")
|
||||
audio_segment, final_sample_rate, spectragram = infer_process(
|
||||
audio_segment, final_sample_rate, spectrogram = infer_process(
|
||||
ref_audio_,
|
||||
ref_text_,
|
||||
gen_text_,
|
||||
@@ -335,7 +348,7 @@ def main():
|
||||
nfe_step=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
speed=local_speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
@@ -344,6 +357,8 @@ def main():
|
||||
if save_chunk:
|
||||
if len(gen_text_) > 200:
|
||||
gen_text_ = gen_text_[:200] + " ... "
|
||||
if use_legacy_text:
|
||||
gen_text_ = unidecode(gen_text_)
|
||||
sf.write(
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
|
||||
audio_segment,
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from importlib.resources import files
|
||||
|
||||
import click
|
||||
@@ -17,6 +19,7 @@ import torchaudio
|
||||
from cached_path import cached_path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
import spaces
|
||||
|
||||
@@ -32,15 +35,16 @@ def gpu_decorator(func):
|
||||
return func
|
||||
|
||||
|
||||
from f5_tts.model import DiT, UNetT
|
||||
from f5_tts.infer.utils_infer import (
|
||||
load_vocoder,
|
||||
load_model,
|
||||
preprocess_ref_audio_text,
|
||||
infer_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
save_spectrogram,
|
||||
tempfile_kwargs,
|
||||
)
|
||||
from f5_tts.model import DiT, UNetT
|
||||
|
||||
|
||||
DEFAULT_TTS_MODEL = "F5-TTS_v1"
|
||||
@@ -78,6 +82,8 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
|
||||
vocab_path = str(cached_path(vocab_path))
|
||||
if model_cfg is None:
|
||||
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
||||
elif isinstance(model_cfg, str):
|
||||
model_cfg = json.loads(model_cfg)
|
||||
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
|
||||
|
||||
|
||||
@@ -90,7 +96,7 @@ chat_tokenizer_state = None
|
||||
|
||||
|
||||
@gpu_decorator
|
||||
def generate_response(messages, model, tokenizer):
|
||||
def chat_model_inference(messages, model, tokenizer):
|
||||
"""Generate response using Qwen"""
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
@@ -112,6 +118,17 @@ def generate_response(messages, model, tokenizer):
|
||||
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
|
||||
@gpu_decorator
|
||||
def load_text_from_file(file):
|
||||
if file:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
text = f.read().strip()
|
||||
else:
|
||||
text = ""
|
||||
return gr.update(value=text)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
|
||||
@gpu_decorator
|
||||
def infer(
|
||||
ref_audio_orig,
|
||||
@@ -119,6 +136,7 @@ def infer(
|
||||
gen_text,
|
||||
model,
|
||||
remove_silence,
|
||||
seed,
|
||||
cross_fade_duration=0.15,
|
||||
nfe_step=32,
|
||||
speed=1,
|
||||
@@ -128,8 +146,15 @@ def infer(
|
||||
gr.Warning("Please provide reference audio.")
|
||||
return gr.update(), gr.update(), ref_text
|
||||
|
||||
# Set inference seed
|
||||
if seed < 0 or seed > 2**31 - 1:
|
||||
gr.Warning("Seed must in range 0 ~ 2147483647. Using random seed instead.")
|
||||
seed = np.random.randint(0, 2**31 - 1)
|
||||
torch.manual_seed(seed)
|
||||
used_seed = seed
|
||||
|
||||
if not gen_text.strip():
|
||||
gr.Warning("Please enter text to generate.")
|
||||
gr.Warning("Please enter text to generate or upload a text file.")
|
||||
return gr.update(), gr.update(), ref_text
|
||||
|
||||
ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
|
||||
@@ -142,7 +167,7 @@ def infer(
|
||||
show_info("Loading E2-TTS model...")
|
||||
E2TTS_ema_model = load_e2tts()
|
||||
ema_model = E2TTS_ema_model
|
||||
elif isinstance(model, list) and model[0] == "Custom":
|
||||
elif isinstance(model, tuple) and model[0] == "Custom":
|
||||
assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
|
||||
global custom_ema_model, pre_custom_path
|
||||
if pre_custom_path != model[1]:
|
||||
@@ -166,44 +191,59 @@ def infer(
|
||||
|
||||
# Remove silence
|
||||
if remove_silence:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
sf.write(f.name, final_wave, final_sample_rate)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
try:
|
||||
sf.write(temp_path, final_wave, final_sample_rate)
|
||||
remove_silence_for_generated_wav(f.name)
|
||||
final_wave, _ = torchaudio.load(f.name)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
final_wave = final_wave.squeeze().cpu().numpy()
|
||||
|
||||
# Save the spectrogram
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
|
||||
spectrogram_path = tmp_spectrogram.name
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
|
||||
return (final_sample_rate, final_wave), spectrogram_path, ref_text
|
||||
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
|
||||
|
||||
|
||||
with gr.Blocks() as app_credits:
|
||||
gr.Markdown("""
|
||||
# Credits
|
||||
|
||||
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
|
||||
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
|
||||
""")
|
||||
with gr.Blocks() as app_tts:
|
||||
gr.Markdown("# Batched TTS")
|
||||
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
|
||||
with gr.Row():
|
||||
gen_text_input = gr.Textbox(
|
||||
label="Text to Generate",
|
||||
lines=10,
|
||||
max_lines=40,
|
||||
scale=4,
|
||||
)
|
||||
gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
|
||||
generate_btn = gr.Button("Synthesize", variant="primary")
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
ref_text_input = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
|
||||
lines=2,
|
||||
)
|
||||
remove_silence = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
|
||||
value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
ref_text_input = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
|
||||
lines=2,
|
||||
scale=4,
|
||||
)
|
||||
ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
|
||||
with gr.Row():
|
||||
randomize_seed = gr.Checkbox(
|
||||
label="Randomize Seed",
|
||||
info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
|
||||
value=True,
|
||||
scale=3,
|
||||
)
|
||||
seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
|
||||
with gr.Column(scale=4):
|
||||
remove_silence = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
info="If undesired long silence(s) produced, turn on to automatically detect and crop.",
|
||||
value=False,
|
||||
)
|
||||
speed_slider = gr.Slider(
|
||||
label="Speed",
|
||||
minimum=0.3,
|
||||
@@ -238,21 +278,45 @@ with gr.Blocks() as app_tts:
|
||||
ref_text_input,
|
||||
gen_text_input,
|
||||
remove_silence,
|
||||
randomize_seed,
|
||||
seed_input,
|
||||
cross_fade_duration_slider,
|
||||
nfe_slider,
|
||||
speed_slider,
|
||||
):
|
||||
audio_out, spectrogram_path, ref_text_out = infer(
|
||||
if randomize_seed:
|
||||
seed_input = np.random.randint(0, 2**31 - 1)
|
||||
|
||||
audio_out, spectrogram_path, ref_text_out, used_seed = infer(
|
||||
ref_audio_input,
|
||||
ref_text_input,
|
||||
gen_text_input,
|
||||
tts_model_choice,
|
||||
remove_silence,
|
||||
seed=seed_input,
|
||||
cross_fade_duration=cross_fade_duration_slider,
|
||||
nfe_step=nfe_slider,
|
||||
speed=speed_slider,
|
||||
)
|
||||
return audio_out, spectrogram_path, ref_text_out
|
||||
return audio_out, spectrogram_path, ref_text_out, used_seed
|
||||
|
||||
gen_text_file.upload(
|
||||
load_text_from_file,
|
||||
inputs=[gen_text_file],
|
||||
outputs=[gen_text_input],
|
||||
)
|
||||
|
||||
ref_text_file.upload(
|
||||
load_text_from_file,
|
||||
inputs=[ref_text_file],
|
||||
outputs=[ref_text_input],
|
||||
)
|
||||
|
||||
ref_audio_input.clear(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[ref_text_input, ref_text_file],
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
basic_tts,
|
||||
@@ -261,35 +325,46 @@ with gr.Blocks() as app_tts:
|
||||
ref_text_input,
|
||||
gen_text_input,
|
||||
remove_silence,
|
||||
randomize_seed,
|
||||
seed_input,
|
||||
cross_fade_duration_slider,
|
||||
nfe_slider,
|
||||
speed_slider,
|
||||
],
|
||||
outputs=[audio_output, spectrogram_output, ref_text_input],
|
||||
outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
|
||||
)
|
||||
|
||||
|
||||
def parse_speechtypes_text(gen_text):
|
||||
# Pattern to find {speechtype}
|
||||
pattern = r"\{(.*?)\}"
|
||||
# Pattern to find {str} or {"name": str, "seed": int, "speed": float}
|
||||
pattern = r"(\{.*?\})"
|
||||
|
||||
# Split the text by the pattern
|
||||
tokens = re.split(pattern, gen_text)
|
||||
|
||||
segments = []
|
||||
|
||||
current_style = "Regular"
|
||||
current_type_dict = {
|
||||
"name": "Regular",
|
||||
"seed": -1,
|
||||
"speed": 1.0,
|
||||
}
|
||||
|
||||
for i in range(len(tokens)):
|
||||
if i % 2 == 0:
|
||||
# This is text
|
||||
text = tokens[i].strip()
|
||||
if text:
|
||||
segments.append({"style": current_style, "text": text})
|
||||
current_type_dict["text"] = text
|
||||
segments.append(current_type_dict)
|
||||
else:
|
||||
# This is style
|
||||
style = tokens[i].strip()
|
||||
current_style = style
|
||||
# This is type
|
||||
type_str = tokens[i].strip()
|
||||
try: # if type dict
|
||||
current_type_dict = json.loads(type_str)
|
||||
except json.decoder.JSONDecodeError:
|
||||
type_str = type_str[1:-1] # remove brace {}
|
||||
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
|
||||
|
||||
return segments
|
||||
|
||||
@@ -300,44 +375,55 @@ with gr.Blocks() as app_multistyle:
|
||||
"""
|
||||
# Multiple Speech-Type Generation
|
||||
|
||||
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
|
||||
This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
|
||||
"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
"""
|
||||
**Example Input:**
|
||||
{Regular} Hello, I'd like to order a sandwich please.
|
||||
{Surprised} What do you mean you're out of bread?
|
||||
{Sad} I really wanted a sandwich though...
|
||||
{Angry} You know what, darn you and your little shop!
|
||||
{Whisper} I'll just go back home and cry now.
|
||||
{Shouting} Why me?!
|
||||
**Example Input:** <br>
|
||||
{Regular} Hello, I'd like to order a sandwich please. <br>
|
||||
{Surprised} What do you mean you're out of bread? <br>
|
||||
{Sad} I really wanted a sandwich though... <br>
|
||||
{Angry} You know what, darn you and your little shop! <br>
|
||||
{Whisper} I'll just go back home and cry now. <br>
|
||||
{Shouting} Why me?!
|
||||
"""
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
**Example Input 2:**
|
||||
{Speaker1_Happy} Hello, I'd like to order a sandwich please.
|
||||
{Speaker2_Regular} Sorry, we're out of bread.
|
||||
{Speaker1_Sad} I really wanted a sandwich though...
|
||||
{Speaker2_Whisper} I'll give you the last one I was hiding.
|
||||
**Example Input 2:** <br>
|
||||
{"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
|
||||
{"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
|
||||
{"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
|
||||
{"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
|
||||
"""
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
|
||||
'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
|
||||
)
|
||||
|
||||
# Regular speech type (mandatory)
|
||||
with gr.Row() as regular_row:
|
||||
with gr.Column():
|
||||
with gr.Row(variant="compact") as regular_row:
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
|
||||
regular_insert = gr.Button("Insert Label", variant="secondary")
|
||||
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
|
||||
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)
|
||||
with gr.Column(scale=3):
|
||||
regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
|
||||
with gr.Column(scale=3):
|
||||
regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
|
||||
with gr.Row():
|
||||
regular_seed_slider = gr.Slider(
|
||||
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
|
||||
)
|
||||
regular_speed_slider = gr.Slider(
|
||||
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
|
||||
)
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
|
||||
|
||||
# Regular speech type (max 100)
|
||||
max_speech_types = 100
|
||||
@@ -345,25 +431,55 @@ with gr.Blocks() as app_multistyle:
|
||||
speech_type_names = [regular_name]
|
||||
speech_type_audios = [regular_audio]
|
||||
speech_type_ref_texts = [regular_ref_text]
|
||||
speech_type_ref_text_files = [regular_ref_text_file]
|
||||
speech_type_seeds = [regular_seed_slider]
|
||||
speech_type_speeds = [regular_speed_slider]
|
||||
speech_type_delete_btns = [None]
|
||||
speech_type_insert_btns = [regular_insert]
|
||||
|
||||
# Additional speech types (99 more)
|
||||
for i in range(max_speech_types - 1):
|
||||
with gr.Row(visible=False) as row:
|
||||
with gr.Column():
|
||||
with gr.Row(variant="compact", visible=False) as row:
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
name_input = gr.Textbox(label="Speech Type Name")
|
||||
delete_btn = gr.Button("Delete Type", variant="secondary")
|
||||
insert_btn = gr.Button("Insert Label", variant="secondary")
|
||||
audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
ref_text_input = gr.Textbox(label="Reference Text", lines=2)
|
||||
delete_btn = gr.Button("Delete Type", variant="stop")
|
||||
with gr.Column(scale=3):
|
||||
audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
with gr.Column(scale=3):
|
||||
ref_text_input = gr.Textbox(label="Reference Text", lines=4)
|
||||
with gr.Row():
|
||||
seed_input = gr.Slider(
|
||||
show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed. -1 for random"
|
||||
)
|
||||
speed_input = gr.Slider(
|
||||
show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
|
||||
)
|
||||
with gr.Column(scale=1, min_width=160):
|
||||
ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
|
||||
speech_type_rows.append(row)
|
||||
speech_type_names.append(name_input)
|
||||
speech_type_audios.append(audio_input)
|
||||
speech_type_ref_texts.append(ref_text_input)
|
||||
speech_type_ref_text_files.append(ref_text_file_input)
|
||||
speech_type_seeds.append(seed_input)
|
||||
speech_type_speeds.append(speed_input)
|
||||
speech_type_delete_btns.append(delete_btn)
|
||||
speech_type_insert_btns.append(insert_btn)
|
||||
|
||||
# Global logic for all speech types
|
||||
for i in range(max_speech_types):
|
||||
speech_type_audios[i].clear(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[speech_type_ref_texts[i], speech_type_ref_text_files[i]],
|
||||
)
|
||||
speech_type_ref_text_files[i].upload(
|
||||
load_text_from_file,
|
||||
inputs=[speech_type_ref_text_files[i]],
|
||||
outputs=[speech_type_ref_texts[i]],
|
||||
)
|
||||
|
||||
# Button to add speech type
|
||||
add_speech_type_btn = gr.Button("Add Speech Type")
|
||||
|
||||
@@ -385,27 +501,44 @@ with gr.Blocks() as app_multistyle:
|
||||
|
||||
# Function to delete a speech type
|
||||
def delete_speech_type_fn():
|
||||
return gr.update(visible=False), None, None, None
|
||||
return gr.update(visible=False), None, None, None, None
|
||||
|
||||
# Update delete button clicks
|
||||
# Update delete button clicks and ref text file changes
|
||||
for i in range(1, len(speech_type_delete_btns)):
|
||||
speech_type_delete_btns[i].click(
|
||||
delete_speech_type_fn,
|
||||
outputs=[speech_type_rows[i], speech_type_names[i], speech_type_audios[i], speech_type_ref_texts[i]],
|
||||
outputs=[
|
||||
speech_type_rows[i],
|
||||
speech_type_names[i],
|
||||
speech_type_audios[i],
|
||||
speech_type_ref_texts[i],
|
||||
speech_type_ref_text_files[i],
|
||||
],
|
||||
)
|
||||
|
||||
# Text input for the prompt
|
||||
gen_text_input_multistyle = gr.Textbox(
|
||||
label="Text to Generate",
|
||||
lines=10,
|
||||
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
|
||||
)
|
||||
with gr.Row():
|
||||
gen_text_input_multistyle = gr.Textbox(
|
||||
label="Text to Generate",
|
||||
lines=10,
|
||||
max_lines=40,
|
||||
scale=4,
|
||||
placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
|
||||
)
|
||||
gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
|
||||
|
||||
def make_insert_speech_type_fn(index):
|
||||
def insert_speech_type_fn(current_text, speech_type_name):
|
||||
def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
|
||||
current_text = current_text or ""
|
||||
speech_type_name = speech_type_name or "None"
|
||||
updated_text = current_text + f"{{{speech_type_name}}} "
|
||||
if not speech_type_name:
|
||||
gr.Warning("Please enter speech type name before insert.")
|
||||
return current_text
|
||||
speech_type_dict = {
|
||||
"name": speech_type_name,
|
||||
"seed": speech_type_seed,
|
||||
"speed": speech_type_speed,
|
||||
}
|
||||
updated_text = current_text + json.dumps(speech_type_dict) + " "
|
||||
return updated_text
|
||||
|
||||
return insert_speech_type_fn
|
||||
@@ -414,15 +547,24 @@ with gr.Blocks() as app_multistyle:
|
||||
insert_fn = make_insert_speech_type_fn(i)
|
||||
insert_btn.click(
|
||||
insert_fn,
|
||||
inputs=[gen_text_input_multistyle, speech_type_names[i]],
|
||||
inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
|
||||
outputs=gen_text_input_multistyle,
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
remove_silence_multistyle = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
value=True,
|
||||
)
|
||||
with gr.Accordion("Advanced Settings", open=True):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
show_cherrypick_multistyle = gr.Checkbox(
|
||||
label="Show Cherry-pick Interface",
|
||||
info="Turn on to show interface, picking seeds from previous generations.",
|
||||
value=False,
|
||||
)
|
||||
with gr.Column():
|
||||
remove_silence_multistyle = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
info="Turn on to automatically detect and crop long silences.",
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Generate button
|
||||
generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
|
||||
@@ -430,6 +572,30 @@ with gr.Blocks() as app_multistyle:
|
||||
# Output audio
|
||||
audio_output_multistyle = gr.Audio(label="Synthesized Audio")
|
||||
|
||||
# Used seed gallery
|
||||
cherrypick_interface_multistyle = gr.Textbox(
|
||||
label="Cherry-pick Interface",
|
||||
lines=10,
|
||||
max_lines=40,
|
||||
show_copy_button=True,
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
|
||||
# Logic control to show/hide the cherrypick interface
|
||||
show_cherrypick_multistyle.change(
|
||||
lambda is_visible: gr.update(visible=is_visible),
|
||||
show_cherrypick_multistyle,
|
||||
cherrypick_interface_multistyle,
|
||||
)
|
||||
|
||||
# Function to load text to generate from file
|
||||
gen_text_file_multistyle.upload(
|
||||
load_text_from_file,
|
||||
inputs=[gen_text_file_multistyle],
|
||||
outputs=[gen_text_input_multistyle],
|
||||
)
|
||||
|
||||
@gpu_decorator
|
||||
def generate_multistyle_speech(
|
||||
gen_text,
|
||||
@@ -457,41 +623,60 @@ with gr.Blocks() as app_multistyle:
|
||||
|
||||
# For each segment, generate speech
|
||||
generated_audio_segments = []
|
||||
current_style = "Regular"
|
||||
current_type_name = "Regular"
|
||||
inference_meta_data = ""
|
||||
|
||||
for segment in segments:
|
||||
style = segment["style"]
|
||||
name = segment["name"]
|
||||
seed_input = segment["seed"]
|
||||
speed = segment["speed"]
|
||||
text = segment["text"]
|
||||
|
||||
if style in speech_types:
|
||||
current_style = style
|
||||
if name in speech_types:
|
||||
current_type_name = name
|
||||
else:
|
||||
gr.Warning(f"Type {style} is not available, will use Regular as default.")
|
||||
current_style = "Regular"
|
||||
gr.Warning(f"Type {name} is not available, will use Regular as default.")
|
||||
current_type_name = "Regular"
|
||||
|
||||
try:
|
||||
ref_audio = speech_types[current_style]["audio"]
|
||||
ref_audio = speech_types[current_type_name]["audio"]
|
||||
except KeyError:
|
||||
gr.Warning(f"Please provide reference audio for type {current_style}.")
|
||||
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
|
||||
ref_text = speech_types[current_style].get("ref_text", "")
|
||||
gr.Warning(f"Please provide reference audio for type {current_type_name}.")
|
||||
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
|
||||
ref_text = speech_types[current_type_name].get("ref_text", "")
|
||||
|
||||
# Generate speech for this segment
|
||||
audio_out, _, ref_text_out = infer(
|
||||
ref_audio, ref_text, text, tts_model_choice, remove_silence, 0, show_info=print
|
||||
) # show_info=print no pull to top when generating
|
||||
if seed_input == -1:
|
||||
seed_input = np.random.randint(0, 2**31 - 1)
|
||||
|
||||
# Generate or retrieve speech for this segment
|
||||
audio_out, _, ref_text_out, used_seed = infer(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
text,
|
||||
tts_model_choice,
|
||||
remove_silence,
|
||||
seed=seed_input,
|
||||
cross_fade_duration=0,
|
||||
speed=speed,
|
||||
show_info=print, # no pull to top when generating
|
||||
)
|
||||
sr, audio_data = audio_out
|
||||
|
||||
generated_audio_segments.append(audio_data)
|
||||
speech_types[current_style]["ref_text"] = ref_text_out
|
||||
speech_types[current_type_name]["ref_text"] = ref_text_out
|
||||
inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
|
||||
|
||||
# Concatenate all audio segments
|
||||
if generated_audio_segments:
|
||||
final_audio_data = np.concatenate(generated_audio_segments)
|
||||
return [(sr, final_audio_data)] + [speech_types[style]["ref_text"] for style in speech_types]
|
||||
return (
|
||||
[(sr, final_audio_data)]
|
||||
+ [speech_types[name]["ref_text"] for name in speech_types]
|
||||
+ [inference_meta_data]
|
||||
)
|
||||
else:
|
||||
gr.Warning("No audio generated.")
|
||||
return [None] + [speech_types[style]["ref_text"] for style in speech_types]
|
||||
return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
|
||||
|
||||
generate_multistyle_btn.click(
|
||||
generate_multistyle_speech,
|
||||
@@ -504,7 +689,7 @@ with gr.Blocks() as app_multistyle:
|
||||
+ [
|
||||
remove_silence_multistyle,
|
||||
],
|
||||
outputs=[audio_output_multistyle] + speech_type_ref_texts,
|
||||
outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
|
||||
)
|
||||
|
||||
# Validation function to disable Generate button if speech types are missing
|
||||
@@ -521,7 +706,7 @@ with gr.Blocks() as app_multistyle:
|
||||
|
||||
# Parse the gen_text to get the speech types used
|
||||
segments = parse_speechtypes_text(gen_text)
|
||||
speech_types_in_text = set(segment["style"] for segment in segments)
|
||||
speech_types_in_text = set(segment["name"] for segment in segments)
|
||||
|
||||
# Check if all speech types in text are available
|
||||
missing_speech_types = speech_types_in_text - speech_types_available
|
||||
@@ -544,10 +729,10 @@ with gr.Blocks() as app_chat:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Voice Chat
|
||||
Have a conversation with an AI using your reference voice!
|
||||
1. Upload a reference audio clip and optionally its transcript.
|
||||
Have a conversation with an AI using your reference voice!
|
||||
1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
|
||||
2. Load the chat model.
|
||||
3. Record your message through your microphone.
|
||||
3. Record your message through your microphone or type it.
|
||||
4. The AI will respond using the reference voice.
|
||||
"""
|
||||
)
|
||||
@@ -603,22 +788,35 @@ Have a conversation with an AI using your reference voice!
|
||||
ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
|
||||
with gr.Column():
|
||||
with gr.Accordion("Advanced Settings", open=False):
|
||||
with gr.Row():
|
||||
ref_text_chat = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Optional: Leave blank to auto-transcribe",
|
||||
lines=2,
|
||||
scale=3,
|
||||
)
|
||||
ref_text_file_chat = gr.File(
|
||||
label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
|
||||
)
|
||||
with gr.Row():
|
||||
randomize_seed_chat = gr.Checkbox(
|
||||
label="Randomize Seed",
|
||||
value=True,
|
||||
info="Uncheck to use the seed specified.",
|
||||
scale=3,
|
||||
)
|
||||
seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
|
||||
remove_silence_chat = gr.Checkbox(
|
||||
label="Remove Silences",
|
||||
value=True,
|
||||
)
|
||||
ref_text_chat = gr.Textbox(
|
||||
label="Reference Text",
|
||||
info="Optional: Leave blank to auto-transcribe",
|
||||
lines=2,
|
||||
)
|
||||
system_prompt_chat = gr.Textbox(
|
||||
label="System Prompt",
|
||||
value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
|
||||
lines=2,
|
||||
)
|
||||
|
||||
chatbot_interface = gr.Chatbot(label="Conversation")
|
||||
chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@@ -635,140 +833,119 @@ Have a conversation with an AI using your reference voice!
|
||||
send_btn_chat = gr.Button("Send Message")
|
||||
clear_btn_chat = gr.Button("Clear Conversation")
|
||||
|
||||
conversation_state = gr.State(
|
||||
value=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# Modify process_audio_input to use model and tokenizer from state
|
||||
# Modify process_audio_input to generate user input
|
||||
@gpu_decorator
|
||||
def process_audio_input(audio_path, text, history, conv_state):
|
||||
def process_audio_input(conv_state, audio_path, text):
|
||||
"""Handle audio or text input from user"""
|
||||
|
||||
if not audio_path and not text.strip():
|
||||
return history, conv_state, ""
|
||||
return conv_state
|
||||
|
||||
if audio_path:
|
||||
text = preprocess_ref_audio_text(audio_path, text)[1]
|
||||
|
||||
if not text.strip():
|
||||
return history, conv_state, ""
|
||||
return conv_state
|
||||
|
||||
conv_state.append({"role": "user", "content": text})
|
||||
history.append((text, None))
|
||||
return conv_state
|
||||
|
||||
response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
|
||||
# Use model and tokenizer from state to get text response
|
||||
@gpu_decorator
|
||||
def generate_text_response(conv_state, system_prompt):
|
||||
"""Generate text response from AI"""
|
||||
|
||||
system_prompt_state = [{"role": "system", "content": system_prompt}]
|
||||
response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
|
||||
|
||||
conv_state.append({"role": "assistant", "content": response})
|
||||
history[-1] = (text, response)
|
||||
|
||||
return history, conv_state, ""
|
||||
return conv_state
|
||||
|
||||
@gpu_decorator
|
||||
def generate_audio_response(history, ref_audio, ref_text, remove_silence):
|
||||
def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
|
||||
"""Generate TTS audio for AI response"""
|
||||
if not history or not ref_audio:
|
||||
return None
|
||||
if not conv_state or not ref_audio:
|
||||
return None, ref_text, seed_input
|
||||
|
||||
last_user_message, last_ai_response = history[-1]
|
||||
if not last_ai_response:
|
||||
return None
|
||||
last_ai_response = conv_state[-1]["content"]
|
||||
if not last_ai_response or conv_state[-1]["role"] != "assistant":
|
||||
return None, ref_text, seed_input
|
||||
|
||||
audio_result, _, ref_text_out = infer(
|
||||
if randomize_seed:
|
||||
seed_input = np.random.randint(0, 2**31 - 1)
|
||||
|
||||
audio_result, _, ref_text_out, used_seed = infer(
|
||||
ref_audio,
|
||||
ref_text,
|
||||
last_ai_response,
|
||||
tts_model_choice,
|
||||
remove_silence,
|
||||
seed=seed_input,
|
||||
cross_fade_duration=0.15,
|
||||
speed=1.0,
|
||||
show_info=print, # show_info=print no pull to top when generating
|
||||
)
|
||||
return audio_result, ref_text_out
|
||||
return audio_result, ref_text_out, used_seed
|
||||
|
||||
def clear_conversation():
|
||||
"""Reset the conversation"""
|
||||
return [], [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
|
||||
}
|
||||
]
|
||||
return [], None
|
||||
|
||||
def update_system_prompt(new_prompt):
|
||||
"""Update the system prompt and reset the conversation"""
|
||||
new_conv_state = [{"role": "system", "content": new_prompt}]
|
||||
return [], new_conv_state
|
||||
|
||||
# Handle audio input
|
||||
audio_input_chat.stop_recording(
|
||||
process_audio_input,
|
||||
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
|
||||
outputs=[audio_output_chat, ref_text_chat],
|
||||
).then(
|
||||
lambda: None,
|
||||
None,
|
||||
audio_input_chat,
|
||||
ref_text_file_chat.upload(
|
||||
load_text_from_file,
|
||||
inputs=[ref_text_file_chat],
|
||||
outputs=[ref_text_chat],
|
||||
)
|
||||
|
||||
# Handle text input
|
||||
text_input_chat.submit(
|
||||
process_audio_input,
|
||||
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
|
||||
outputs=[audio_output_chat, ref_text_chat],
|
||||
).then(
|
||||
lambda: None,
|
||||
None,
|
||||
text_input_chat,
|
||||
)
|
||||
for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
|
||||
user_operation(
|
||||
process_audio_input,
|
||||
inputs=[chatbot_interface, audio_input_chat, text_input_chat],
|
||||
outputs=[chatbot_interface],
|
||||
).then(
|
||||
generate_text_response,
|
||||
inputs=[chatbot_interface, system_prompt_chat],
|
||||
outputs=[chatbot_interface],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[
|
||||
chatbot_interface,
|
||||
ref_audio_chat,
|
||||
ref_text_chat,
|
||||
remove_silence_chat,
|
||||
randomize_seed_chat,
|
||||
seed_input_chat,
|
||||
],
|
||||
outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
|
||||
).then(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[audio_input_chat, text_input_chat],
|
||||
)
|
||||
|
||||
# Handle send button
|
||||
send_btn_chat.click(
|
||||
process_audio_input,
|
||||
inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
).then(
|
||||
generate_audio_response,
|
||||
inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, remove_silence_chat],
|
||||
outputs=[audio_output_chat, ref_text_chat],
|
||||
).then(
|
||||
lambda: None,
|
||||
None,
|
||||
text_input_chat,
|
||||
)
|
||||
# Handle clear button or system prompt change and reset conversation
|
||||
for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
|
||||
user_operation(
|
||||
clear_conversation,
|
||||
outputs=[chatbot_interface, audio_output_chat],
|
||||
)
|
||||
|
||||
# Handle clear button
|
||||
clear_btn_chat.click(
|
||||
clear_conversation,
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
)
|
||||
|
||||
# Handle system prompt change and reset conversation
|
||||
system_prompt_chat.change(
|
||||
update_system_prompt,
|
||||
inputs=system_prompt_chat,
|
||||
outputs=[chatbot_interface, conversation_state],
|
||||
)
|
||||
with gr.Blocks() as app_credits:
|
||||
gr.Markdown("""
|
||||
# Credits
|
||||
|
||||
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
|
||||
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
|
||||
""")
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
f"""
|
||||
# E2/F5 TTS
|
||||
# F5-TTS Demo Space
|
||||
|
||||
This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
|
||||
This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
|
||||
|
||||
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
|
||||
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
|
||||
@@ -798,7 +975,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
global tts_model_choice
|
||||
if new_choice == "Custom": # override in case webpage is refreshed
|
||||
custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
|
||||
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
|
||||
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
|
||||
return (
|
||||
gr.update(visible=True, value=custom_ckpt_path),
|
||||
gr.update(visible=True, value=custom_vocab_path),
|
||||
@@ -810,7 +987,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
|
||||
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
|
||||
global tts_model_choice
|
||||
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
|
||||
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
|
||||
with open(last_used_custom, "w", encoding="utf-8") as f:
|
||||
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
|
||||
from importlib.resources import files
|
||||
@@ -7,6 +8,7 @@ from importlib.resources import files
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
@@ -14,6 +16,7 @@ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectro
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
@@ -55,7 +58,8 @@ win_length = model_cfg.model.mel_spec.win_length
|
||||
n_fft = model_cfg.model.mel_spec.n_fft
|
||||
|
||||
|
||||
ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
# ckpt_path = str(files("f5_tts").joinpath("../../")) + f"/ckpts/{exp_name}/model_{ckpt_step}.safetensors"
|
||||
ckpt_path = str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.safetensors"))
|
||||
output_dir = "tests"
|
||||
|
||||
|
||||
@@ -152,7 +156,7 @@ for part in parts_to_edit:
|
||||
dim=-1,
|
||||
)
|
||||
offset = end * target_sample_rate
|
||||
# audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
|
||||
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
|
||||
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
|
||||
audio = audio.to(device)
|
||||
edit_mask = edit_mask.to(device)
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
||||
|
||||
@@ -14,6 +15,7 @@ from importlib.resources import files
|
||||
|
||||
import matplotlib
|
||||
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
import matplotlib.pylab as plt
|
||||
@@ -27,12 +29,11 @@ from transformers import pipeline
|
||||
from vocos import Vocos
|
||||
|
||||
from f5_tts.model import CFM
|
||||
from f5_tts.model.utils import (
|
||||
get_tokenizer,
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
_ref_audio_cache = {}
|
||||
_ref_text_cache = {}
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
@@ -44,6 +45,8 @@ device = (
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
|
||||
|
||||
# -----------------------------------------
|
||||
|
||||
target_sample_rate = 24000
|
||||
@@ -290,62 +293,74 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio_orig, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
|
||||
global _ref_audio_cache
|
||||
|
||||
if audio_hash in _ref_audio_cache:
|
||||
show_info("Using cached preprocessed reference audio...")
|
||||
ref_audio = _ref_audio_cache[audio_hash]
|
||||
|
||||
else: # first pass, do preprocess
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
if clip_short:
|
||||
# 1. try to find long silence for clipping
|
||||
# 1. try to find long silence for clipping
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
aseg = non_silent_wave
|
||||
|
||||
aseg = non_silent_wave
|
||||
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
|
||||
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
||||
aseg.export(f.name, format="wav")
|
||||
ref_audio = f.name
|
||||
aseg.export(temp_path, format="wav")
|
||||
ref_audio = temp_path
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
# Cache the processed reference audio
|
||||
_ref_audio_cache[audio_hash] = ref_audio
|
||||
|
||||
if not ref_text.strip():
|
||||
global _ref_audio_cache
|
||||
if audio_hash in _ref_audio_cache:
|
||||
global _ref_text_cache
|
||||
if audio_hash in _ref_text_cache:
|
||||
# Use cached asr transcription
|
||||
show_info("Using cached reference text...")
|
||||
ref_text = _ref_audio_cache[audio_hash]
|
||||
ref_text = _ref_text_cache[audio_hash]
|
||||
else:
|
||||
show_info("No reference text provided, transcribing reference audio...")
|
||||
ref_text = transcribe(ref_audio)
|
||||
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
|
||||
_ref_audio_cache[audio_hash] = ref_text
|
||||
_ref_text_cache[audio_hash] = ref_text
|
||||
else:
|
||||
show_info("Using custom reference text...")
|
||||
|
||||
@@ -384,7 +399,7 @@ def infer_process(
|
||||
):
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from f5_tts.model.cfm import CFM
|
||||
|
||||
from f5_tts.model.backbones.unett import UNetT
|
||||
from f5_tts.model.backbones.dit import DiT
|
||||
from f5_tts.model.backbones.mmdit import MMDiT
|
||||
|
||||
from f5_tts.model.backbones.unett import UNetT
|
||||
from f5_tts.model.cfm import CFM
|
||||
from f5_tts.model.trainer import Trainer
|
||||
|
||||
|
||||
|
||||
@@ -10,19 +10,18 @@ d - dimension
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
AdaLayerNorm_Final,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNorm_Final,
|
||||
TimestepEmbedding,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,11 +29,16 @@ from f5_tts.model.modules import (
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
||||
def __init__(
|
||||
self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
|
||||
):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
||||
self.average_upsampling = average_upsampling # zipvoice-style text late average upsampling (after text encoder)
|
||||
if average_upsampling:
|
||||
assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
@@ -46,11 +50,46 @@ class TextEmbedding(nn.Module):
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
|
||||
batch, text_len, text_dim = text.shape
|
||||
|
||||
if audio_mask is None:
|
||||
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
|
||||
valid_mask = audio_mask & text_mask
|
||||
audio_lens = audio_mask.sum(dim=1) # [batch]
|
||||
valid_lens = valid_mask.sum(dim=1) # [batch]
|
||||
|
||||
upsampled_text = torch.zeros_like(text)
|
||||
|
||||
for i in range(batch):
|
||||
audio_len = audio_lens[i].item()
|
||||
valid_len = valid_lens[i].item()
|
||||
|
||||
if valid_len == 0:
|
||||
continue
|
||||
|
||||
valid_ind = torch.where(valid_mask[i])[0]
|
||||
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
|
||||
|
||||
base_repeat = audio_len // valid_len
|
||||
remainder = audio_len % valid_len
|
||||
|
||||
indices = []
|
||||
for j in range(valid_len):
|
||||
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
||||
indices.extend([j] * repeat_count)
|
||||
|
||||
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
||||
upsampled = valid_data[indices] # [audio_len, text_dim]
|
||||
|
||||
upsampled_text[i, :audio_len, :] = upsampled
|
||||
|
||||
return upsampled_text
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
@@ -62,10 +101,7 @@ class TextEmbedding(nn.Module):
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||
text_pos_embed = self.freqs_cis[pos_idx]
|
||||
text = text + text_pos_embed
|
||||
text = text + self.freqs_cis[:seq_len, :]
|
||||
|
||||
# convnextv2 blocks
|
||||
if self.mask_padding:
|
||||
@@ -76,6 +112,9 @@ class TextEmbedding(nn.Module):
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
if self.average_upsampling:
|
||||
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -114,9 +153,12 @@ class DiT(nn.Module):
|
||||
text_num_embeds=256,
|
||||
text_dim=None,
|
||||
text_mask_padding=True,
|
||||
text_embedding_average_upsampling=False,
|
||||
qk_norm=None,
|
||||
conv_layers=0,
|
||||
pe_attn_head=None,
|
||||
attn_backend="torch", # "torch" | "flash_attn"
|
||||
attn_mask_enabled=False,
|
||||
long_skip_connection=False,
|
||||
checkpoint_activations=False,
|
||||
):
|
||||
@@ -126,7 +168,11 @@ class DiT(nn.Module):
|
||||
if text_dim is None:
|
||||
text_dim = mel_dim
|
||||
self.text_embed = TextEmbedding(
|
||||
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
||||
text_num_embeds,
|
||||
text_dim,
|
||||
mask_padding=text_mask_padding,
|
||||
average_upsampling=text_embedding_average_upsampling,
|
||||
conv_layers=conv_layers,
|
||||
)
|
||||
self.text_cond, self.text_uncond = None, None # text cache
|
||||
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||
@@ -146,6 +192,8 @@ class DiT(nn.Module):
|
||||
dropout=dropout,
|
||||
qk_norm=qk_norm,
|
||||
pe_attn_head=pe_attn_head,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
@@ -179,6 +227,48 @@ class DiT(nn.Module):
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def get_input_embed(
|
||||
self,
|
||||
x, # b n d
|
||||
cond, # b n d
|
||||
text, # b nt
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
audio_mask: bool["b n"] | None = None, # noqa: F722
|
||||
):
|
||||
if self.text_uncond is None or self.text_cond is None or not cache:
|
||||
if audio_mask is None:
|
||||
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
|
||||
else:
|
||||
batch = x.shape[0]
|
||||
seq_lens = audio_mask.sum(dim=1)
|
||||
text_embed_list = []
|
||||
for i in range(batch):
|
||||
text_embed_i = self.text_embed(
|
||||
text[i].unsqueeze(0),
|
||||
seq_lens[i].item(),
|
||||
drop_text=drop_text,
|
||||
audio_mask=audio_mask,
|
||||
)
|
||||
text_embed_list.append(text_embed_i[0])
|
||||
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
|
||||
if cache:
|
||||
if drop_text:
|
||||
self.text_uncond = text_embed
|
||||
else:
|
||||
self.text_cond = text_embed
|
||||
|
||||
if cache:
|
||||
if drop_text:
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
text_embed = self.text_cond
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
return x
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
@@ -188,10 +278,11 @@ class DiT(nn.Module):
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
cache: bool = False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
@@ -199,18 +290,20 @@ class DiT(nn.Module):
|
||||
|
||||
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
||||
t = self.time_embed(time)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
||||
text_embed = self.text_cond
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
|
||||
)
|
||||
x_uncond = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
|
||||
)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
x = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
|
||||
)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
|
||||
@@ -11,16 +11,15 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
AdaLayerNorm_Final,
|
||||
ConvPositionEmbedding,
|
||||
MMDiTBlock,
|
||||
AdaLayerNorm_Final,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
@@ -142,26 +141,15 @@ class MMDiT(nn.Module):
|
||||
nn.init.constant_(self.proj_out.weight, 0)
|
||||
nn.init.constant_(self.proj_out.bias, 0)
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
def get_input_embed(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
x, # b n d
|
||||
cond, # b n d
|
||||
text, # b nt
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
):
|
||||
batch = x.shape[0]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
@@ -175,6 +163,41 @@ class MMDiT(nn.Module):
|
||||
c = self.text_embed(text, drop_text=drop_text)
|
||||
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
return x, c
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
cache: bool = False,
|
||||
):
|
||||
batch = x.shape[0]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
|
||||
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
c = torch.cat((c_cond, c_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
else:
|
||||
x, c = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
|
||||
)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
text_len = text.shape[1]
|
||||
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
@@ -8,24 +8,24 @@ d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
from x_transformers import RMSNorm
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
FeedForward,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,6 +120,8 @@ class UNetT(nn.Module):
|
||||
qk_norm=None,
|
||||
conv_layers=0,
|
||||
pe_attn_head=None,
|
||||
attn_backend="torch", # "torch" | "flash_attn"
|
||||
attn_mask_enabled=False,
|
||||
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -150,7 +152,11 @@ class UNetT(nn.Module):
|
||||
|
||||
attn_norm = RMSNorm(dim)
|
||||
attn = Attention(
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
processor=AttnProcessor(
|
||||
pe_attn_head=pe_attn_head,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
@@ -178,26 +184,16 @@ class UNetT(nn.Module):
|
||||
self.norm_out = RMSNorm(dim)
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
def get_input_embed(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
x, # b n d
|
||||
cond, # b n d
|
||||
text, # b nt
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
seq_len = x.shape[1]
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
@@ -209,8 +205,41 @@ class UNetT(nn.Module):
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
return x
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
cache: bool = False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
|
||||
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
else:
|
||||
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
|
||||
|
||||
# postfix time t to input x, [b n d] -> [b n+1 d]
|
||||
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
|
||||
if mask is not None:
|
||||
|
||||
@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
|
||||
from f5_tts.model.utils import (
|
||||
default,
|
||||
exists,
|
||||
get_epss_timesteps,
|
||||
lens_to_mask,
|
||||
list_str_to_idx,
|
||||
list_str_to_tensor,
|
||||
@@ -92,6 +93,7 @@ class CFM(nn.Module):
|
||||
seed: int | None = None,
|
||||
max_duration=4096,
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||
use_epss=True,
|
||||
no_ref_audio=False,
|
||||
duplicate_test=False,
|
||||
t_inter=0.1,
|
||||
@@ -160,16 +162,31 @@ class CFM(nn.Module):
|
||||
# at each step, conditioning is fixed
|
||||
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
||||
|
||||
# predict flow
|
||||
pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
|
||||
)
|
||||
# predict flow (cond)
|
||||
if cfg_strength < 1e-5:
|
||||
pred = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
drop_audio_cond=False,
|
||||
drop_text=False,
|
||||
cache=True,
|
||||
)
|
||||
return pred
|
||||
|
||||
null_pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
|
||||
# predict flow (cond and uncond), for classifier-free guidance
|
||||
pred_cfg = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
cfg_infer=True,
|
||||
cache=True,
|
||||
)
|
||||
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
|
||||
return pred + (pred - null_pred) * cfg_strength
|
||||
|
||||
# noise input
|
||||
@@ -190,7 +207,10 @@ class CFM(nn.Module):
|
||||
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||
steps = int(steps * (1 - t_start))
|
||||
|
||||
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
||||
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
|
||||
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
|
||||
else:
|
||||
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
||||
if sway_sampling_coef is not None:
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
@@ -232,10 +252,9 @@ class CFM(nn.Module):
|
||||
assert text.shape[0] == batch
|
||||
|
||||
# lens and mask
|
||||
if not exists(lens):
|
||||
if not exists(lens): # if lens not acquired by trainer from collate_fn
|
||||
lens = torch.full((batch,), seq_len, device=device)
|
||||
|
||||
mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch
|
||||
mask = lens_to_mask(lens, length=seq_len)
|
||||
|
||||
# get a random span to mask out for training conditionally
|
||||
frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
|
||||
@@ -270,10 +289,9 @@ class CFM(nn.Module):
|
||||
else:
|
||||
drop_text = False
|
||||
|
||||
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
|
||||
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
||||
# apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
|
||||
pred = self.transformer(
|
||||
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
|
||||
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
|
||||
)
|
||||
|
||||
# flow matching loss
|
||||
|
||||
@@ -312,7 +312,7 @@ def collate_fn(batch):
|
||||
max_mel_length = mel_lengths.amax()
|
||||
|
||||
padded_mel_specs = []
|
||||
for spec in mel_specs: # TODO. maybe records mask for attention here
|
||||
for spec in mel_specs:
|
||||
padding = (0, max_mel_length - spec.size(-1))
|
||||
padded_spec = F.pad(spec, padding, value=0)
|
||||
padded_mel_specs.append(padded_spec)
|
||||
@@ -324,7 +324,7 @@ def collate_fn(batch):
|
||||
|
||||
return dict(
|
||||
mel=mel_specs,
|
||||
mel_lengths=mel_lengths,
|
||||
mel_lengths=mel_lengths, # records for padding mask
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# flake8: noqa
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
|
||||
from torch import nn
|
||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||
|
||||
from f5_tts.model.utils import is_package_available
|
||||
|
||||
|
||||
# raw wav to mel spec
|
||||
|
||||
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
@@ -417,9 +420,9 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b n d"] = None, # context c # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # noised input x
|
||||
c: float["b n d"] = None, # context c
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.Tensor:
|
||||
@@ -431,19 +434,30 @@ class Attention(nn.Module):
|
||||
|
||||
# Attention processor
|
||||
|
||||
if is_package_available("flash_attn"):
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn import flash_attn_varlen_func, flash_attn_func
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
|
||||
attn_backend: str = "torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled: bool = True,
|
||||
):
|
||||
if attn_backend == "flash_attn":
|
||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||
|
||||
self.pe_attn_head = pe_attn_head
|
||||
self.attn_backend = attn_backend
|
||||
self.attn_mask_enabled = attn_mask_enabled
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # noised input x
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = x.shape[0]
|
||||
@@ -479,16 +493,40 @@ class AttnProcessor:
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = mask
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
if self.attn_backend == "torch":
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
attn_mask = mask
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
elif self.attn_backend == "flash_attn":
|
||||
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
|
||||
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
|
||||
value, _, _, _, _ = unpad_input(value, mask)
|
||||
x = flash_attn_varlen_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_cu_seqlens,
|
||||
k_cu_seqlens,
|
||||
q_max_seqlen_in_batch,
|
||||
k_max_seqlen_in_batch,
|
||||
)
|
||||
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
|
||||
x = x.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
else:
|
||||
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
|
||||
x = x.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@@ -514,9 +552,9 @@ class JointAttnProcessor:
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # noised input x
|
||||
c: float["b nt d"] = None, # context c, here text
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.FloatTensor:
|
||||
@@ -608,12 +646,27 @@ class JointAttnProcessor:
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads,
|
||||
dim_head,
|
||||
ff_mult=4,
|
||||
dropout=0.1,
|
||||
qk_norm=None,
|
||||
pe_attn_head=None,
|
||||
attn_backend="torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
processor=AttnProcessor(
|
||||
pe_attn_head=pe_attn_head,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
|
||||
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, timestep: float["b"]): # noqa: F821
|
||||
def forward(self, timestep: float["b"]):
|
||||
time_hidden = self.time_embed(timestep)
|
||||
time_hidden = time_hidden.to(timestep.dtype)
|
||||
time = self.time_mlp(time_hidden) # b d
|
||||
|
||||
@@ -19,6 +19,7 @@ from f5_tts.model import CFM
|
||||
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
|
||||
from f5_tts.model.utils import default, exists
|
||||
|
||||
|
||||
# trainer
|
||||
|
||||
|
||||
@@ -148,7 +149,7 @@ class Trainer:
|
||||
if self.is_main:
|
||||
checkpoint = dict(
|
||||
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
||||
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
||||
optimizer_state_dict=self.optimizer.state_dict(),
|
||||
ema_model_state_dict=self.ema_model.state_dict(),
|
||||
scheduler_state_dict=self.scheduler.state_dict(),
|
||||
update=update,
|
||||
@@ -241,7 +242,7 @@ class Trainer:
|
||||
del checkpoint["model_state_dict"][key]
|
||||
|
||||
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
|
||||
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
if self.scheduler:
|
||||
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
update = checkpoint["update"]
|
||||
|
||||
@@ -5,11 +5,10 @@ import random
|
||||
from collections import defaultdict
|
||||
from importlib.resources import files
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import jieba
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
import torch
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
# seed everything
|
||||
@@ -36,6 +35,16 @@ def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
|
||||
def is_package_available(package_name: str) -> bool:
|
||||
try:
|
||||
import importlib
|
||||
|
||||
package_exists = importlib.util.find_spec(package_name) is not None
|
||||
return package_exists
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# tensor helpers
|
||||
|
||||
|
||||
@@ -190,3 +199,22 @@ def repetition_found(text, length=2, tolerance=10):
|
||||
if count > tolerance:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# get the empirically pruned step for sampling
|
||||
|
||||
|
||||
def get_epss_timesteps(n, device, dtype):
|
||||
dt = 1 / 32
|
||||
predefined_timesteps = {
|
||||
5: [0, 2, 4, 8, 16, 32],
|
||||
6: [0, 2, 4, 6, 8, 16, 32],
|
||||
7: [0, 2, 4, 6, 8, 16, 24, 32],
|
||||
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
|
||||
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
}
|
||||
t = predefined_timesteps.get(n, [])
|
||||
if not t:
|
||||
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
|
||||
return dt * torch.tensor(t, device=device, dtype=dtype)
|
||||
|
||||
3
src/f5_tts/runtime/triton_trtllm/.gitignore
vendored
Normal file
3
src/f5_tts/runtime/triton_trtllm/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# runtime/triton_trtllm related
|
||||
model.cache
|
||||
model_repo/
|
||||
@@ -1,47 +1,79 @@
|
||||
## Triton Inference Serving Best Practice for F5-TTS
|
||||
|
||||
### Quick Start
|
||||
Directly launch the service using docker compose.
|
||||
### Setup
|
||||
#### Option 1: Quick Start
|
||||
```sh
|
||||
# TODO: support F5TTS_v1_Base
|
||||
MODEL=F5TTS_Base docker compose up
|
||||
# Directly launch the service using docker compose
|
||||
MODEL=F5TTS_v1_Base docker compose up
|
||||
```
|
||||
|
||||
### Build Image
|
||||
Build the docker image from scratch.
|
||||
#### Option 2: Build from scratch
|
||||
```sh
|
||||
# Build the docker image
|
||||
docker build . -f Dockerfile.server -t soar97/triton-f5-tts:24.12
|
||||
```
|
||||
|
||||
### Create Docker Container
|
||||
```sh
|
||||
# Create Docker Container
|
||||
your_mount_dir=/mnt:/mnt
|
||||
docker run -it --name "f5-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-f5-tts:24.12
|
||||
```
|
||||
|
||||
### Export Models to TensorRT-LLM and Launch Server
|
||||
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).
|
||||
### Build TensorRT-LLM Engines and Launch Server
|
||||
Inside docker container, we would follow the official guide of TensorRT-LLM to build qwen and whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/whisper).
|
||||
```sh
|
||||
bash run.sh 0 4 F5TTS_Base
|
||||
# F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
|
||||
bash run.sh 0 4 F5TTS_v1_Base
|
||||
```
|
||||
> [!NOTE]
|
||||
> If use custom checkpoint, set `ckpt_file` and `vocab_file` in `run.sh`.
|
||||
> Remember to used matched model version (`F5TTS_v1_*` for v1, `F5TTS_*` for v0).
|
||||
>
|
||||
> If use checkpoint of different structure, see `scripts/convert_checkpoint.py`, and perform modification if necessary.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> If train or finetune with fp32, add `--dtype float32` flag when converting checkpoint in `run.sh` phase 1.
|
||||
|
||||
### HTTP Client
|
||||
```sh
|
||||
python3 client_http.py
|
||||
```
|
||||
|
||||
### Benchmark using Dataset
|
||||
### Benchmarking
|
||||
#### Using Client-Server Mode
|
||||
```sh
|
||||
# bash run.sh 5 5 F5TTS_v1_Base
|
||||
num_task=2
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts
|
||||
```
|
||||
|
||||
### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs.
|
||||
#### Using Offline TRT-LLM Mode
|
||||
```sh
|
||||
# bash run.sh 7 7 F5TTS_v1_Base
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./tests/benchmark_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
```
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF |
|
||||
|-------|-------------|----------------|-------|
|
||||
| F5-TTS Base (Vocos) | 1 | 253 ms | 0.0394|
|
||||
### Benchmark Results
|
||||
Decoding on a single L20 GPU, using 26 different prompt_audio & target_text pairs, 16 NFE.
|
||||
|
||||
| Model | Concurrency | Avg Latency | RTF | Mode |
|
||||
|---------------------|----------------|-------------|--------|-----------------|
|
||||
| F5-TTS Base (Vocos) | 2 | 253 ms | 0.0394 | Client-Server |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.0402 | Offline TRT-LLM |
|
||||
| F5-TTS Base (Vocos) | 1 (Batch_size) | - | 0.1467 | Offline Pytorch |
|
||||
|
||||
### Credits
|
||||
1. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
1. [Yuekai Zhang](https://github.com/yuekaizhang)
|
||||
2. [F5-TTS-TRTLLM](https://github.com/Bigfishering/f5-tts-trtllm)
|
||||
|
||||
473
src/f5_tts/runtime/triton_trtllm/benchmark.py
Normal file
473
src/f5_tts/runtime/triton_trtllm/benchmark.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
|
||||
""" Example Usage
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $CKPT_DIR/$model/model_1200000.pt \
|
||||
--vocab-file $CKPT_DIR/$model/vocab.txt \
|
||||
--vocoder-trt-engine-path $vocoder_trt_engine_path \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import datasets
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from tensorrt_llm._utils import trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session, TensorInfo
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from vocos import Vocos
|
||||
|
||||
|
||||
sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")
|
||||
|
||||
from f5_tts.eval.utils_eval import padded_mel_batch
|
||||
from f5_tts.model.modules import get_vocos_mel_spectrogram
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx
|
||||
|
||||
|
||||
F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="extract speech code")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||
help="huggingface dataset split name",
|
||||
)
|
||||
parser.add_argument("--output-dir", required=True, type=str, help="dir to save result")
|
||||
parser.add_argument(
|
||||
"--vocab-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="vocab file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="model path, to load text embedding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tllm-model-dir",
|
||||
required=True,
|
||||
type=str,
|
||||
help="tllm model dir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
required=True,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument("--num-workers", type=int, default=0, help="workers for dataloader")
|
||||
parser.add_argument("--prefetch", type=int, default=None, help="prefetch for dataloader")
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
default="vocos",
|
||||
type=str,
|
||||
help="vocoder name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocoder-trt-engine-path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="vocoder trt engine path",
|
||||
)
|
||||
parser.add_argument("--enable-warmup", action="store_true")
|
||||
parser.add_argument("--remove-input-padding", action="store_true")
|
||||
parser.add_argument("--use-perf", action="store_true", help="use nvtx to record performance")
|
||||
parser.add_argument("--backend-type", type=str, default="triton", choices=["trt", "pytorch"], help="backend type")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("data_collator")
|
||||
target_sample_rate = 24000
|
||||
target_rms = 0.1
|
||||
(
|
||||
ids,
|
||||
ref_rms_list,
|
||||
ref_mel_list,
|
||||
ref_mel_len_list,
|
||||
estimated_reference_target_mel_len,
|
||||
reference_target_texts_list,
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
for i, item in enumerate(batch):
|
||||
item_id, prompt_text, target_text = (
|
||||
item["id"],
|
||||
item["prompt_text"],
|
||||
item["target_text"],
|
||||
)
|
||||
ids.append(item_id)
|
||||
reference_target_texts_list.append(prompt_text + target_text)
|
||||
|
||||
ref_audio_org, ref_sr = (
|
||||
item["prompt_audio"]["array"],
|
||||
item["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||
ref_rms_list.append(ref_rms)
|
||||
if ref_rms < target_rms:
|
||||
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
|
||||
ref_audio = ref_audio.to("cuda")
|
||||
ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
ref_mel_len = ref_mel.shape[-1]
|
||||
assert ref_mel.shape[0] == 100
|
||||
|
||||
ref_mel_list.append(ref_mel)
|
||||
ref_mel_len_list.append(ref_mel_len)
|
||||
|
||||
estimated_reference_target_mel_len.append(
|
||||
int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
|
||||
)
|
||||
|
||||
ref_mel_batch = padded_mel_batch(ref_mel_list)
|
||||
ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)
|
||||
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)
|
||||
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
return {
|
||||
"ids": ids,
|
||||
"ref_rms_list": ref_rms_list,
|
||||
"ref_mel_batch": ref_mel_batch,
|
||||
"ref_mel_len_batch": ref_mel_len_batch,
|
||||
"text_pad_sequence": text_pad_sequence,
|
||||
"estimated_reference_target_mel_len": estimated_reference_target_mel_len,
|
||||
}
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
# Initialize process group with explicit device IDs
|
||||
dist.init_process_group(
|
||||
"nccl",
|
||||
)
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def load_vocoder(
|
||||
vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
|
||||
):
|
||||
if vocoder_name == "vocos":
|
||||
if vocoder_trt_engine_path is not None:
|
||||
vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
|
||||
else:
|
||||
# vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
|
||||
if is_local:
|
||||
print(f"Load vocos from local path {local_path}")
|
||||
config_path = f"{local_path}/config.yaml"
|
||||
model_path = f"{local_path}/pytorch_model.bin"
|
||||
else:
|
||||
print("Download Vocos from huggingface charactr/vocos-mel-24khz")
|
||||
repo_id = "charactr/vocos-mel-24khz"
|
||||
config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
|
||||
model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
|
||||
vocoder = Vocos.from_hparams(config_path)
|
||||
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
||||
from vocos.feature_extractors import EncodecFeatures
|
||||
|
||||
if isinstance(vocoder.feature_extractor, EncodecFeatures):
|
||||
encodec_parameters = {
|
||||
"feature_extractor.encodec." + key: value
|
||||
for key, value in vocoder.feature_extractor.encodec.state_dict().items()
|
||||
}
|
||||
state_dict.update(encodec_parameters)
|
||||
vocoder.load_state_dict(state_dict)
|
||||
vocoder = vocoder.eval().to(device)
|
||||
elif vocoder_name == "bigvgan":
|
||||
raise NotImplementedError("BigVGAN is not implemented yet")
|
||||
return vocoder
|
||||
|
||||
|
||||
class VocosTensorRT:
|
||||
def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
|
||||
logger.info(f"Loading vocoder engine from {engine_path}")
|
||||
self.engine_path = engine_path
|
||||
with open(engine_path, "rb") as f:
|
||||
engine_buffer = f.read()
|
||||
self.session = Session.from_serialized_engine(engine_buffer)
|
||||
self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream
|
||||
|
||||
def decode(self, mels):
|
||||
mels = mels.contiguous()
|
||||
inputs = {"mel": mels}
|
||||
output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
|
||||
outputs = {
|
||||
t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
|
||||
}
|
||||
ok = self.session.run(inputs, outputs, self.stream)
|
||||
|
||||
assert ok, "Runtime execution failed for vae session"
|
||||
|
||||
samples = outputs["waveform"]
|
||||
return samples
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")
|
||||
|
||||
tllm_model_dir = args.tllm_model_dir
|
||||
with open(os.path.join(tllm_model_dir, "config.json")) as f:
|
||||
tllm_model_config = json.load(f)
|
||||
if args.backend_type == "trt":
|
||||
model = F5TTS(
|
||||
tllm_model_config,
|
||||
debug_mode=False,
|
||||
tllm_model_dir=tllm_model_dir,
|
||||
model_path=args.model_path,
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
from f5_tts.infer.utils_infer import load_model
|
||||
from f5_tts.model import DiT
|
||||
|
||||
pretrained_config = tllm_model_config["pretrained_config"]
|
||||
pt_model_config = dict(
|
||||
dim=pretrained_config["hidden_size"],
|
||||
depth=pretrained_config["num_hidden_layers"],
|
||||
heads=pretrained_config["num_attention_heads"],
|
||||
ff_mult=pretrained_config["ff_mult"],
|
||||
text_dim=pretrained_config["text_dim"],
|
||||
text_mask_padding=pretrained_config["text_mask_padding"],
|
||||
conv_layers=pretrained_config["conv_layers"],
|
||||
pe_attn_head=pretrained_config["pe_attn_head"],
|
||||
)
|
||||
model = load_model(DiT, pt_model_config, args.model_path)
|
||||
|
||||
vocoder = load_vocoder(
|
||||
vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
|
||||
)
|
||||
|
||||
dataset = load_dataset(
|
||||
"yuekai/seed_tts",
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
def add_estimated_duration(example):
|
||||
prompt_audio_len = example["prompt_audio"]["array"].shape[0]
|
||||
scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
|
||||
estimated_duration = prompt_audio_len * scale_factor
|
||||
example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
|
||||
return example
|
||||
|
||||
dataset = dataset.map(add_estimated_duration)
|
||||
dataset = dataset.sort("estimated_duration", reverse=True)
|
||||
if args.use_perf:
|
||||
# dataset_list = [dataset.select(range(1)) for i in range(16)] # seq_len 1000
|
||||
dataset_list_short = [dataset.select([24]) for i in range(8)] # seq_len 719
|
||||
# dataset_list_long = [dataset.select([23]) for i in range(8)] # seq_len 2002
|
||||
# dataset = datasets.concatenate_datasets(dataset_list_short + dataset_list_long)
|
||||
dataset = datasets.concatenate_datasets(dataset_list_short)
|
||||
if world_size > 1:
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
else:
|
||||
# This would disable shuffling
|
||||
sampler = None
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
|
||||
)
|
||||
|
||||
total_steps = len(dataset)
|
||||
|
||||
if args.enable_warmup:
|
||||
for batch in dataloader:
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
|
||||
if args.backend_type == "trt":
|
||||
_ = model.sample(
|
||||
text_pad_seq,
|
||||
cond_pad_seq,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
steps=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
decoding_time = 0
|
||||
vocoder_time = 0
|
||||
total_duration = 0
|
||||
if args.use_perf:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
total_decoding_time = time.time()
|
||||
for batch in dataloader:
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_push("data sample")
|
||||
ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
|
||||
text_pad_seq = batch["text_pad_sequence"].to(device)
|
||||
total_mel_lens = batch["estimated_reference_target_mel_len"]
|
||||
cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
if args.backend_type == "trt":
|
||||
generated, cost_time = model.sample(
|
||||
text_pad_seq,
|
||||
cond_pad_seq,
|
||||
ref_mel_lens,
|
||||
total_mel_lens,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
use_perf=args.use_perf,
|
||||
)
|
||||
elif args.backend_type == "pytorch":
|
||||
total_mel_lens = torch.tensor(total_mel_lens, device=device)
|
||||
with torch.inference_mode():
|
||||
start_time = time.time()
|
||||
generated, _ = model.sample(
|
||||
cond=ref_mels,
|
||||
text=text_pad_seq,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=32,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=-1,
|
||||
)
|
||||
cost_time = time.time() - start_time
|
||||
decoding_time += cost_time
|
||||
vocoder_start_time = time.time()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24000
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
if args.vocoder == "vocos":
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_push("vocoder decode")
|
||||
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||
if args.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
|
||||
if batch["ref_rms_list"][i] < target_rms:
|
||||
generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms
|
||||
|
||||
utt = batch["ids"][i]
|
||||
torchaudio.save(
|
||||
f"{args.output_dir}/{utt}.wav",
|
||||
generated_wave,
|
||||
target_sample_rate,
|
||||
)
|
||||
total_duration += generated_wave.shape[1] / target_sample_rate
|
||||
vocoder_time += time.time() - vocoder_start_time
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * len(batch["ids"]))
|
||||
total_decoding_time = time.time() - total_decoding_time
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
rtf = total_decoding_time / total_duration
|
||||
s = f"RTF: {rtf:.4f}\n"
|
||||
s += f"total_duration: {total_duration:.3f} seconds\n"
|
||||
s += f"({total_duration / 3600:.2f} hours)\n"
|
||||
s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
|
||||
s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
|
||||
s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
|
||||
s += f"batch size: {args.batch_size}\n"
|
||||
print(s)
|
||||
|
||||
with open(f"{args.output_dir}/rtf.txt", "w") as f:
|
||||
f.write(s)
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -30,21 +30,11 @@ python3 client_grpc.py \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name test_zh \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
|
||||
# For offline Spark-TTS-0.5B
|
||||
python3 client_grpc.py \
|
||||
--server-addr localhost \
|
||||
--model-name spark_tts \
|
||||
--num-tasks $num_task \
|
||||
--huggingface-dataset yuekai/seed_tts \
|
||||
--split-name wenetspeech4tts \
|
||||
--log-dir ./log_concurrent_tasks_${num_task}
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
@@ -177,8 +167,7 @@ def get_args():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -207,7 +196,7 @@ def get_args():
|
||||
"--log-dir",
|
||||
type=str,
|
||||
required=False,
|
||||
default="./tmp",
|
||||
default="./tests/client_grpc",
|
||||
help="log directory",
|
||||
)
|
||||
|
||||
@@ -221,8 +210,8 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
@@ -231,8 +220,7 @@ def load_audio(wav_path, target_sample_rate=16000):
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
|
||||
waveform = resample(waveform, num_samples)
|
||||
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
@@ -245,7 +233,7 @@ async def send(
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
save_sample_rate: int = 24000,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -255,7 +243,7 @@ async def send(
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
|
||||
duration = len(waveform) / sample_rate
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
|
||||
@@ -311,8 +299,9 @@ async def send(
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, estimated_target_duration))
|
||||
total_duration += estimated_target_duration
|
||||
actual_duration = len(audio) / save_sample_rate
|
||||
latency_data.append((end, actual_duration))
|
||||
total_duration += actual_duration
|
||||
|
||||
return total_duration, latency_data
|
||||
|
||||
@@ -417,7 +406,7 @@ async def main():
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
||||
save_sample_rate=24000,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
@@ -23,10 +23,12 @@
|
||||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import argparse
|
||||
|
||||
|
||||
def get_args():
|
||||
@@ -64,33 +66,32 @@ def get_args():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="f5_tts",
|
||||
choices=["f5_tts", "spark_tts"],
|
||||
help="triton model_repo module name to request",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-audio",
|
||||
type=str,
|
||||
default="output.wav",
|
||||
default="tests/client_http.wav",
|
||||
help="Path to save the output audio",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prepare_request(
|
||||
samples,
|
||||
waveform,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
sample_rate=24000,
|
||||
audio_save_dir: str = "./",
|
||||
):
|
||||
assert len(samples.shape) == 1, "samples should be 1D"
|
||||
lengths = np.array([[len(samples)]], dtype=np.int32)
|
||||
samples = samples.reshape(1, -1).astype(np.float32)
|
||||
assert len(waveform.shape) == 1, "waveform should be 1D"
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
waveform = waveform.reshape(1, -1).astype(np.float32)
|
||||
|
||||
data = {
|
||||
"inputs": [
|
||||
{"name": "reference_wav", "shape": samples.shape, "datatype": "FP32", "data": samples.tolist()},
|
||||
{"name": "reference_wav", "shape": waveform.shape, "datatype": "FP32", "data": waveform.tolist()},
|
||||
{
|
||||
"name": "reference_wav_len",
|
||||
"shape": lengths.shape,
|
||||
@@ -105,19 +106,18 @@ def prepare_request(
|
||||
return data
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
samples = wav_path["array"]
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
else:
|
||||
samples, sample_rate = sf.read(wav_path)
|
||||
waveform, sample_rate = sf.read(wav_path)
|
||||
if sample_rate != target_sample_rate:
|
||||
from scipy.signal import resample
|
||||
|
||||
num_samples = int(len(samples) * (target_sample_rate / sample_rate))
|
||||
samples = resample(samples, num_samples)
|
||||
return samples, target_sample_rate
|
||||
waveform = resample(waveform, int(len(waveform) * (target_sample_rate / sample_rate)))
|
||||
return waveform, target_sample_rate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -127,11 +127,11 @@ if __name__ == "__main__":
|
||||
server_url = f"http://{server_url}"
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
samples, sr = load_audio(args.reference_audio)
|
||||
assert sr == 16000, "sample rate hardcoded in server"
|
||||
waveform, sr = load_audio(args.reference_audio)
|
||||
assert sr == 24000, "sample rate hardcoded in server"
|
||||
|
||||
samples = np.array(samples, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
waveform = np.array(waveform, dtype=np.float32)
|
||||
data = prepare_request(waveform, args.reference_text, args.target_text)
|
||||
|
||||
rsp = requests.post(
|
||||
url, headers={"Content-Type": "application/json"}, json=data, verify=False, params={"request_id": "0"}
|
||||
@@ -139,4 +139,5 @@ if __name__ == "__main__":
|
||||
result = rsp.json()
|
||||
audio = result["outputs"][0]["data"]
|
||||
audio = np.array(audio, dtype=np.float32)
|
||||
os.makedirs(os.path.dirname(args.output_audio), exist_ok=True)
|
||||
sf.write(args.output_audio, audio, 24000, "PCM_16")
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import tensorrt as trt
|
||||
import os
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from functools import wraps
|
||||
from typing import List, Optional
|
||||
|
||||
import tensorrt as trt
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.runtime.session import Session
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
@@ -33,26 +33,35 @@ def remove_tensor_padding(input_tensor, input_tensor_lengths=None):
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2, precompute_max_pos=4096):
|
||||
def __init__(
|
||||
self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2, precompute_max_pos=4096
|
||||
):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
self.mask_padding = mask_padding
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
|
||||
|
||||
def forward(self, text):
|
||||
# only keep tensors with value not -1
|
||||
text_mask = text != -1
|
||||
text_pad_cut_off_index = text_mask.sum(dim=1).max()
|
||||
def forward(self, text, seq_len, drop_text=False):
|
||||
text = text + 1
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
text = text + self.freqs_cis[:seq_len, :]
|
||||
if self.mask_padding:
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
text = text.masked_fill(text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0)
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
text = text[:, :text_pad_cut_off_index]
|
||||
text = self.text_embed(text)
|
||||
text = text + self.freqs_cis[: text.shape[1], :]
|
||||
for block in self.text_blocks:
|
||||
text = block(text)
|
||||
# padding text to the original length
|
||||
# text shape: B,seq_len,C
|
||||
# pad at the second dimension
|
||||
text = F.pad(text, (0, 0, 0, text_mask.shape[1] - text.shape[1], 0, 0), value=0)
|
||||
return text
|
||||
|
||||
|
||||
@@ -113,20 +122,33 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
|
||||
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_path, use_ema=True):
|
||||
checkpoint = torch.load(ckpt_path, weights_only=True)
|
||||
def get_text_embed_dict(ckpt_path, use_ema=True):
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
checkpoint = load_file(ckpt_path)
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
|
||||
if use_ema:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"ema_model_state_dict": checkpoint}
|
||||
checkpoint["model_state_dict"] = {
|
||||
k.replace("ema_model.", ""): v
|
||||
for k, v in checkpoint["ema_model_state_dict"].items()
|
||||
if k not in ["initted", "step"]
|
||||
}
|
||||
dict_state = checkpoint["model_state_dict"]
|
||||
else:
|
||||
if ckpt_type == "safetensors":
|
||||
checkpoint = {"model_state_dict": checkpoint}
|
||||
model_params = checkpoint["model_state_dict"]
|
||||
|
||||
text_embed_dict = {}
|
||||
for key in dict_state.keys():
|
||||
for key in model_params.keys():
|
||||
# transformer.text_embed.text_embed.weight -> text_embed.weight
|
||||
if "text_embed" in key:
|
||||
text_embed_dict[key.replace("transformer.text_embed.", "")] = dict_state[key]
|
||||
text_embed_dict[key.replace("transformer.text_embed.", "")] = model_params[key]
|
||||
return text_embed_dict
|
||||
|
||||
|
||||
@@ -195,18 +217,16 @@ class F5TTS(object):
|
||||
|
||||
self.max_mel_len = 4096
|
||||
self.text_embedding = TextEmbedding(
|
||||
text_num_embeds=vocab_size, text_dim=512, conv_layers=4, precompute_max_pos=self.max_mel_len
|
||||
text_num_embeds=vocab_size,
|
||||
text_dim=config["pretrained_config"]["text_dim"],
|
||||
mask_padding=config["pretrained_config"]["text_mask_padding"],
|
||||
conv_layers=config["pretrained_config"]["conv_layers"],
|
||||
precompute_max_pos=self.max_mel_len,
|
||||
).to(self.device)
|
||||
self.text_embedding.load_state_dict(load_checkpoint(model_path), strict=True)
|
||||
self.text_embedding.load_state_dict(get_text_embed_dict(model_path), strict=True)
|
||||
|
||||
self.target_audio_sample_rate = 24000
|
||||
self.target_rms = 0.15 # target rms for audio
|
||||
self.n_fft = 1024
|
||||
self.win_length = 1024
|
||||
self.hop_length = 256
|
||||
self.n_mel_channels = 100
|
||||
# self.max_mel_len = 3000
|
||||
self.head_dim = 64
|
||||
self.n_mel_channels = config["pretrained_config"]["mel_dim"]
|
||||
self.head_dim = config["pretrained_config"]["dim_head"]
|
||||
self.base_rescale_factor = 1.0
|
||||
self.interpolation_factor = 1.0
|
||||
base = 10000.0 * self.base_rescale_factor ** (self.head_dim / (self.head_dim - 2))
|
||||
@@ -215,14 +235,23 @@ class F5TTS(object):
|
||||
self.freqs = freqs.repeat_interleave(2, dim=-1).unsqueeze(0)
|
||||
self.rope_cos = self.freqs.cos().half()
|
||||
self.rope_sin = self.freqs.sin().half()
|
||||
self.nfe_steps = 16
|
||||
t = torch.linspace(0, 1, self.nfe_steps + 1, dtype=torch.float32)
|
||||
time_step = t + (-1.0) * (torch.cos(torch.pi * 0.5 * t) - 1 + t)
|
||||
|
||||
self.nfe_steps = 32
|
||||
epss = {
|
||||
5: [0, 2, 4, 8, 16, 32],
|
||||
6: [0, 2, 4, 6, 8, 16, 32],
|
||||
7: [0, 2, 4, 6, 8, 16, 24, 32],
|
||||
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
|
||||
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
}
|
||||
t = 1 / 32 * torch.tensor(epss.get(self.nfe_steps, list(range(self.nfe_steps + 1))), dtype=torch.float32)
|
||||
time_step = 1 - torch.cos(torch.pi * t / 2)
|
||||
delta_t = torch.diff(time_step)
|
||||
# WAR: hard coding 256 here
|
||||
tmp_dim = 256
|
||||
time_expand = torch.zeros((1, self.nfe_steps, tmp_dim), dtype=torch.float32)
|
||||
half_dim = tmp_dim // 2
|
||||
|
||||
freq_embed_dim = 256 # Warning: hard coding 256 here
|
||||
time_expand = torch.zeros((1, self.nfe_steps, freq_embed_dim), dtype=torch.float32)
|
||||
half_dim = freq_embed_dim // 2
|
||||
emb_factor = math.log(10000) / (half_dim - 1)
|
||||
emb_factor = 1000.0 * torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
|
||||
for i in range(self.nfe_steps):
|
||||
@@ -345,7 +374,7 @@ class F5TTS(object):
|
||||
def sample(
|
||||
self,
|
||||
text_pad_sequence: torch.Tensor,
|
||||
ref_mel_batch: torch.Tensor,
|
||||
cond_pad_sequence: torch.Tensor,
|
||||
ref_mel_len_batch: torch.Tensor,
|
||||
estimated_reference_target_mel_len: List[int],
|
||||
remove_input_padding: bool = False,
|
||||
@@ -354,26 +383,43 @@ class F5TTS(object):
|
||||
if use_perf:
|
||||
torch.cuda.nvtx.range_push("text embedding")
|
||||
batch = text_pad_sequence.shape[0]
|
||||
max_seq_len = ref_mel_batch.shape[1]
|
||||
max_seq_len = cond_pad_sequence.shape[1]
|
||||
|
||||
text_pad_sequence_drop = torch.cat(
|
||||
(text_pad_sequence, torch.zeros((1, text_pad_sequence.shape[1]), dtype=torch.int32).to(self.device)), dim=0
|
||||
# get text_embed one by one to avoid misalignment
|
||||
text_and_drop_embedding_list = []
|
||||
for i in range(batch):
|
||||
text_embedding_i = self.text_embedding(
|
||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||
estimated_reference_target_mel_len[i],
|
||||
drop_text=False,
|
||||
)
|
||||
text_embedding_drop_i = self.text_embedding(
|
||||
text_pad_sequence[i].unsqueeze(0).to(self.device),
|
||||
estimated_reference_target_mel_len[i],
|
||||
drop_text=True,
|
||||
)
|
||||
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
|
||||
|
||||
# pad separately computed text_embed to form batch with max_seq_len
|
||||
text_and_drop_embedding = pad_sequence(
|
||||
text_and_drop_embedding_list,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
text_embedding = text_and_drop_embedding[0::2]
|
||||
text_embedding_drop = text_and_drop_embedding[1::2]
|
||||
|
||||
text_embedding_drop_list = []
|
||||
for i in range(batch + 1):
|
||||
text_embedding_drop_list.append(self.text_embedding(text_pad_sequence_drop[i].unsqueeze(0).to(self.device)))
|
||||
text_embedding_drop_condition = torch.cat(text_embedding_drop_list, dim=0)
|
||||
|
||||
text_embedding = text_embedding_drop_condition[:-1]
|
||||
# text_embedding_drop B,T,C batch should be the same
|
||||
text_embedding_drop = text_embedding_drop_condition[-1].unsqueeze(0).repeat(batch, 1, 1)
|
||||
|
||||
noise = torch.randn_like(ref_mel_batch).to(self.device)
|
||||
noise = torch.randn_like(cond_pad_sequence).to(self.device)
|
||||
rope_cos = self.rope_cos[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
||||
rope_sin = self.rope_sin[:, :max_seq_len, :].float().repeat(batch, 1, 1)
|
||||
|
||||
cat_mel_text = torch.cat((ref_mel_batch, text_embedding), dim=-1)
|
||||
cat_mel_text = torch.cat(
|
||||
(
|
||||
cond_pad_sequence,
|
||||
text_embedding,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
cat_mel_text_drop = torch.cat(
|
||||
(
|
||||
torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device),
|
||||
|
||||
@@ -24,16 +24,16 @@
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
import json
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
import torchaudio
|
||||
import jieba
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
import os
|
||||
|
||||
import jieba
|
||||
import torch
|
||||
import torchaudio
|
||||
import triton_python_backend_utils as pb_utils
|
||||
from f5_tts_trtllm import F5TTS
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
||||
|
||||
def get_tokenizer(vocab_file_path: str):
|
||||
@@ -98,7 +98,8 @@ def list_str_to_idx(
|
||||
padding_value=-1,
|
||||
): # noqa: F722
|
||||
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
||||
return list_idx_tensors
|
||||
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||
return text
|
||||
|
||||
|
||||
class TritonPythonModel:
|
||||
@@ -106,13 +107,12 @@ class TritonPythonModel:
|
||||
self.use_perf = True
|
||||
self.device = torch.device("cuda")
|
||||
self.target_audio_sample_rate = 24000
|
||||
self.target_rms = 0.15 # target rms for audio
|
||||
self.target_rms = 0.1 # least rms when inference, normalize to if lower
|
||||
self.n_fft = 1024
|
||||
self.win_length = 1024
|
||||
self.hop_length = 256
|
||||
self.n_mel_channels = 100
|
||||
self.max_mel_len = 3000
|
||||
self.head_dim = 64
|
||||
self.max_mel_len = 4096
|
||||
|
||||
parameters = json.loads(args["model_config"])["parameters"]
|
||||
for key, value in parameters.items():
|
||||
@@ -180,7 +180,8 @@ class TritonPythonModel:
|
||||
reference_target_texts_list,
|
||||
estimated_reference_target_mel_len,
|
||||
reference_mel_len,
|
||||
) = [], [], [], [], []
|
||||
reference_rms_list,
|
||||
) = [], [], [], [], [], []
|
||||
mel_features_list = []
|
||||
if self.use_perf:
|
||||
torch.cuda.nvtx.range_push("preprocess")
|
||||
@@ -207,6 +208,7 @@ class TritonPythonModel:
|
||||
ref_rms = torch.sqrt(torch.mean(torch.square(wav)))
|
||||
if ref_rms < self.target_rms:
|
||||
wav = wav * self.target_rms / ref_rms
|
||||
reference_rms_list.append(ref_rms)
|
||||
if self.reference_sample_rate != self.target_audio_sample_rate:
|
||||
wav = self.resampler(wav)
|
||||
wav = wav.to(self.device)
|
||||
@@ -219,13 +221,15 @@ class TritonPythonModel:
|
||||
|
||||
reference_mel_len.append(mel_features.shape[1])
|
||||
estimated_reference_target_mel_len.append(
|
||||
int(mel_features.shape[1] * (1 + len(target_text) / len(reference_text)))
|
||||
int(
|
||||
mel_features.shape[1] * (1 + len(target_text.encode("utf-8")) / len(reference_text.encode("utf-8")))
|
||||
)
|
||||
)
|
||||
|
||||
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
|
||||
|
||||
batch = len(requests)
|
||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
|
||||
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
|
||||
for i, mel in enumerate(mel_features_list):
|
||||
mel_features[i, : mel.shape[1], :] = mel
|
||||
|
||||
@@ -234,15 +238,6 @@ class TritonPythonModel:
|
||||
pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
|
||||
text_pad_sequence = list_str_to_idx(pinyin_list, self.vocab_char_map)
|
||||
|
||||
for i, item in enumerate(text_pad_sequence):
|
||||
text_pad_sequence[i] = F.pad(
|
||||
item, (0, estimated_reference_target_mel_len[i] - len(item)), mode="constant", value=-1
|
||||
)
|
||||
text_pad_sequence[i] += 1 # WAR: 0 is reserved for padding token, hard coding in F5-TTS
|
||||
text_pad_sequence = pad_sequence(text_pad_sequence, padding_value=-1, batch_first=True).to(self.device)
|
||||
text_pad_sequence = F.pad(
|
||||
text_pad_sequence, (0, max_seq_len - text_pad_sequence.shape[1]), mode="constant", value=-1
|
||||
)
|
||||
if self.use_perf:
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
@@ -259,13 +254,12 @@ class TritonPythonModel:
|
||||
|
||||
responses = []
|
||||
for i in range(batch):
|
||||
ref_me_len = reference_mel_len[i]
|
||||
ref_mel_len = reference_mel_len[i]
|
||||
estimated_mel_len = estimated_reference_target_mel_len[i]
|
||||
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
|
||||
audio = self.forward_vocoder(denoised_one_item)
|
||||
rms = torch.sqrt(torch.mean(torch.square(audio)))
|
||||
if rms < self.target_rms:
|
||||
audio = audio * self.target_rms / rms
|
||||
if reference_rms_list[i] < self.target_rms:
|
||||
audio = audio * reference_rms_list[i] / self.target_rms
|
||||
|
||||
audio = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
|
||||
inference_response = pb_utils.InferenceResponse(output_tensors=[audio])
|
||||
|
||||
@@ -33,7 +33,7 @@ parameters [
|
||||
},
|
||||
{
|
||||
key: "reference_audio_sample_rate",
|
||||
value: {string_value:"16000"}
|
||||
value: {string_value:"24000"}
|
||||
},
|
||||
{
|
||||
key: "vocoder",
|
||||
|
||||
@@ -34,6 +34,7 @@ from .deepseek_v2.model import DeepseekV2ForCausalLM
|
||||
from .dit.model import DiT
|
||||
from .eagle.model import EagleForCausalLM
|
||||
from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
|
||||
from .f5tts.model import F5TTS
|
||||
from .falcon.config import FalconConfig
|
||||
from .falcon.model import FalconForCausalLM, FalconModel
|
||||
from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
|
||||
@@ -54,12 +55,12 @@ from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodi
|
||||
from .mpt.model import MPTForCausalLM, MPTModel
|
||||
from .nemotron_nas.model import DeciLMForCausalLM
|
||||
from .opt.model import OPTForCausalLM, OPTModel
|
||||
from .phi3.model import Phi3ForCausalLM, Phi3Model
|
||||
from .phi.model import PhiForCausalLM, PhiModel
|
||||
from .phi3.model import Phi3ForCausalLM, Phi3Model
|
||||
from .qwen.model import QWenForCausalLM
|
||||
from .recurrentgemma.model import RecurrentGemmaForCausalLM
|
||||
from .redrafter.model import ReDrafterForCausalLM
|
||||
from .f5tts.model import F5TTS
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BertModel",
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import tensorrt as trt
|
||||
from collections import OrderedDict
|
||||
from tensorrt_llm._common import default_net
|
||||
|
||||
from ..._utils import str_dtype_to_trt
|
||||
from ...functional import Tensor, concat
|
||||
from ...layers import Linear
|
||||
from ...module import Module, ModuleList
|
||||
from ...plugin import current_all_reduce_helper
|
||||
from ..modeling_utils import PretrainedConfig, PretrainedModel
|
||||
from ...functional import Tensor, concat
|
||||
from ...module import Module, ModuleList
|
||||
from tensorrt_llm._common import default_net
|
||||
from ...layers import Linear
|
||||
from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
|
||||
|
||||
from .modules import (
|
||||
TimestepEmbedding,
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
)
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
parent_dir = os.path.dirname(current_file_path)
|
||||
@@ -53,6 +50,7 @@ class F5TTS(PretrainedModel):
|
||||
dim_head=config.dim_head,
|
||||
ff_mult=config.ff_mult,
|
||||
dropout=config.dropout,
|
||||
pe_attn_head=config.pe_attn_head,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
@@ -82,13 +80,12 @@ class F5TTS(PretrainedModel):
|
||||
def prepare_inputs(self, **kwargs):
|
||||
max_batch_size = kwargs["max_batch_size"]
|
||||
batch_size_range = [2, 2, max_batch_size]
|
||||
mel_size = 100
|
||||
max_seq_len = 3000
|
||||
num_frames_range = [200, 2 * max_seq_len, max_seq_len * max_batch_size]
|
||||
hidden_size = 512
|
||||
concat_feature_dim = mel_size + hidden_size
|
||||
freq_embed_dim = 256
|
||||
head_dim = 64
|
||||
mel_size = self.config.mel_dim
|
||||
max_seq_len = 3000 # 4096
|
||||
num_frames_range = [mel_size * 2, max_seq_len * 2, max_seq_len * max_batch_size]
|
||||
concat_feature_dim = mel_size + self.config.text_dim
|
||||
freq_embed_dim = 256 # Warning: hard coding 256 here
|
||||
head_dim = self.config.dim_head
|
||||
mapping = self.config.mapping
|
||||
if mapping.tp_size > 1:
|
||||
current_all_reduce_helper().set_workspace_tensor(mapping, 1)
|
||||
|
||||
@@ -3,33 +3,35 @@ from __future__ import annotations
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
from tensorrt_llm._common import default_net
|
||||
from ..._utils import trt_dtype_to_np, str_dtype_to_trt
|
||||
|
||||
from ..._utils import str_dtype_to_trt, trt_dtype_to_np
|
||||
from ...functional import (
|
||||
Tensor,
|
||||
bert_attention,
|
||||
cast,
|
||||
chunk,
|
||||
concat,
|
||||
constant,
|
||||
expand,
|
||||
expand_dims,
|
||||
expand_dims_like,
|
||||
expand_mask,
|
||||
gelu,
|
||||
matmul,
|
||||
permute,
|
||||
shape,
|
||||
silu,
|
||||
slice,
|
||||
permute,
|
||||
expand_mask,
|
||||
expand_dims_like,
|
||||
unsqueeze,
|
||||
matmul,
|
||||
softmax,
|
||||
squeeze,
|
||||
cast,
|
||||
gelu,
|
||||
unsqueeze,
|
||||
view,
|
||||
)
|
||||
from ...functional import expand_dims, view, bert_attention
|
||||
from ...layers import LayerNorm, Linear, Conv1d, Mish, RowLinear, ColumnLinear
|
||||
from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
|
||||
from ...module import Module
|
||||
|
||||
|
||||
@@ -225,29 +227,52 @@ def rotate_every_two_3dim(tensor: Tensor) -> Tensor:
|
||||
return out
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin):
|
||||
if default_net().plugin_config.remove_input_padding:
|
||||
rot_dim = shape(rope_cos, -1) # 64
|
||||
new_t_shape = concat([shape(x, 0), rot_dim]) # (-1, 64)
|
||||
x_ = slice(x, [0, 0], new_t_shape, [1, 1])
|
||||
end_dim = shape(x, -1) - shape(rope_cos, -1)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, rot_dim]), new_t_unrotated_shape, [1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
else:
|
||||
rot_dim = shape(rope_cos, 2) # 64
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), rot_dim]) # (2, -1, 64)
|
||||
x_ = slice(x, [0, 0, 0], new_t_shape, [1, 1, 1])
|
||||
end_dim = shape(x, 2) - shape(rope_cos, 2)
|
||||
new_t_unrotated_shape = concat([shape(x, 0), shape(x, 1), end_dim]) # (2, -1, 960)
|
||||
x_unrotated = slice(x, concat([0, 0, rot_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
out = concat([x_ * rope_cos + rotate_every_two_3dim(x_) * rope_sin, x_unrotated], dim=-1)
|
||||
def apply_rotary_pos_emb_3dim(x, rope_cos, rope_sin, pe_attn_head):
|
||||
full_dim = x.size(-1)
|
||||
head_dim = rope_cos.size(-1) # attn head dim, e.g. 64
|
||||
if pe_attn_head is None:
|
||||
pe_attn_head = full_dim // head_dim
|
||||
rotated_dim = head_dim * pe_attn_head
|
||||
|
||||
rotated_and_unrotated_list = []
|
||||
|
||||
if default_net().plugin_config.remove_input_padding: # for [N, D] input
|
||||
new_t_shape = concat([shape(x, 0), head_dim]) # (2, -1, 64)
|
||||
|
||||
for i in range(pe_attn_head):
|
||||
x_slice_i = slice(x, [0, i * 64], new_t_shape, [1, 1])
|
||||
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
|
||||
rotated_and_unrotated_list.append(x_rotated_i)
|
||||
|
||||
new_t_unrotated_shape = concat([shape(x, 0), full_dim - rotated_dim]) # (2, -1, 1024 - 64 * pe_attn_head)
|
||||
x_unrotated = slice(x, concat([0, rotated_dim]), new_t_unrotated_shape, [1, 1])
|
||||
rotated_and_unrotated_list.append(x_unrotated)
|
||||
|
||||
else: # for [B, N, D] input
|
||||
new_t_shape = concat([shape(x, 0), shape(x, 1), head_dim]) # (2, -1, 64)
|
||||
|
||||
for i in range(pe_attn_head):
|
||||
x_slice_i = slice(x, [0, 0, i * 64], new_t_shape, [1, 1, 1])
|
||||
x_rotated_i = x_slice_i * rope_cos + rotate_every_two_3dim(x_slice_i) * rope_sin
|
||||
rotated_and_unrotated_list.append(x_rotated_i)
|
||||
|
||||
new_t_unrotated_shape = concat(
|
||||
[shape(x, 0), shape(x, 1), full_dim - rotated_dim]
|
||||
) # (2, -1, 1024 - 64 * pe_attn_head)
|
||||
x_unrotated = slice(x, concat([0, 0, rotated_dim]), new_t_unrotated_shape, [1, 1, 1])
|
||||
rotated_and_unrotated_list.append(x_unrotated)
|
||||
|
||||
out = concat(rotated_and_unrotated_list, dim=-1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: Optional[int] = None, # number of attention head to apply rope, None for all
|
||||
):
|
||||
self.pe_attn_head = pe_attn_head
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -263,8 +288,8 @@ class AttnProcessor:
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
# k,v,q all (2,1226,1024)
|
||||
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin)
|
||||
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin)
|
||||
query = apply_rotary_pos_emb_3dim(query, rope_cos, rope_sin, self.pe_attn_head)
|
||||
key = apply_rotary_pos_emb_3dim(key, rope_cos, rope_sin, self.pe_attn_head)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
@@ -352,12 +377,12 @@ class AttnProcessor:
|
||||
|
||||
# DiT Block
|
||||
class DiTBlock(Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
|
||||
@@ -1,64 +1,66 @@
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
model=$3 # F5TTS_Base
|
||||
model=$3 # F5TTS_v1_Base | F5TTS_Base | F5TTS_v1_Small | F5TTS_Small
|
||||
if [ -z "$model" ]; then
|
||||
echo "Model is none"
|
||||
exit 1
|
||||
model=F5TTS_v1_Base
|
||||
fi
|
||||
echo "Start stage: $stage, Stop stage: $stop_stage, Model: $model"
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
F5_TTS_HF_DOWNLOAD_PATH=./F5-TTS
|
||||
F5_TTS_TRT_LLM_CHECKPOINT_PATH=./trtllm_ckpt
|
||||
F5_TTS_TRT_LLM_ENGINE_PATH=./f5_trt_llm_engine
|
||||
CKPT_DIR=../../../../ckpts
|
||||
TRTLLM_CKPT_DIR=$CKPT_DIR/$model/trtllm_ckpt
|
||||
TRTLLM_ENGINE_DIR=$CKPT_DIR/$model/trtllm_engine
|
||||
|
||||
vocoder_trt_engine_path=vocos_vocoder.plan
|
||||
model_repo=./model_repo
|
||||
VOCODER_ONNX_PATH=$CKPT_DIR/vocos_vocoder.onnx
|
||||
VOCODER_TRT_ENGINE_PATH=$CKPT_DIR/vocos_vocoder.plan
|
||||
MODEL_REPO=./model_repo
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
echo "Downloading f5 tts from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS --local-dir $F5_TTS_HF_DOWNLOAD_PATH
|
||||
|
||||
echo "Downloading F5-TTS from huggingface"
|
||||
huggingface-cli download SWivid/F5-TTS $model/model_*.* $model/vocab.txt --local-dir $CKPT_DIR
|
||||
fi
|
||||
|
||||
ckpt_file=$(ls $CKPT_DIR/$model/model_*.* 2>/dev/null | sort -V | tail -1) # default select latest update
|
||||
vocab_file=$CKPT_DIR/$model/vocab.txt
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
echo "Converting checkpoint"
|
||||
python3 ./scripts/convert_checkpoint.py \
|
||||
--timm_ckpt "$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt" \
|
||||
--output_dir "$F5_TTS_TRT_LLM_CHECKPOINT_PATH" --model_name $model
|
||||
python3 scripts/convert_checkpoint.py \
|
||||
--pytorch_ckpt $ckpt_file \
|
||||
--output_dir $TRTLLM_CKPT_DIR --model_name $model
|
||||
python_package_path=/usr/local/lib/python3.12/dist-packages
|
||||
cp -r patch/* $python_package_path/tensorrt_llm/models
|
||||
trtllm-build --checkpoint_dir $F5_TTS_TRT_LLM_CHECKPOINT_PATH \
|
||||
trtllm-build --checkpoint_dir $TRTLLM_CKPT_DIR \
|
||||
--max_batch_size 8 \
|
||||
--output_dir $F5_TTS_TRT_LLM_ENGINE_PATH --remove_input_padding disable
|
||||
--output_dir $TRTLLM_ENGINE_DIR --remove_input_padding disable
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
echo "Exporting vocos vocoder"
|
||||
onnx_vocoder_path=vocos_vocoder.onnx
|
||||
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $onnx_vocoder_path
|
||||
bash scripts/export_vocos_trt.sh $onnx_vocoder_path $vocoder_trt_engine_path
|
||||
python3 scripts/export_vocoder_to_onnx.py --vocoder vocos --output-path $VOCODER_ONNX_PATH
|
||||
bash scripts/export_vocos_trt.sh $VOCODER_ONNX_PATH $VOCODER_TRT_ENGINE_PATH
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
echo "Building triton server"
|
||||
rm -r $model_repo
|
||||
cp -r ./model_repo_f5_tts $model_repo
|
||||
python3 scripts/fill_template.py -i $model_repo/f5_tts/config.pbtxt vocab:$F5_TTS_HF_DOWNLOAD_PATH/$model/vocab.txt,model:$F5_TTS_HF_DOWNLOAD_PATH/$model/model_1200000.pt,trtllm:$F5_TTS_TRT_LLM_ENGINE_PATH,vocoder:vocos
|
||||
cp $vocoder_trt_engine_path $model_repo/vocoder/1/vocoder.plan
|
||||
rm -r $MODEL_REPO
|
||||
cp -r ./model_repo_f5_tts $MODEL_REPO
|
||||
python3 scripts/fill_template.py -i $MODEL_REPO/f5_tts/config.pbtxt vocab:$vocab_file,model:$ckpt_file,trtllm:$TRTLLM_ENGINE_DIR,vocoder:vocos
|
||||
cp $VOCODER_TRT_ENGINE_PATH $MODEL_REPO/vocoder/1/vocoder.plan
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
echo "Starting triton server"
|
||||
tritonserver --model-repository=$model_repo
|
||||
tritonserver --model-repository=$MODEL_REPO
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
echo "Testing triton server"
|
||||
num_task=1
|
||||
log_dir=./log_concurrent_tasks_${num_task}
|
||||
split_name=wenetspeech4tts
|
||||
log_dir=./tests/client_grpc_${model}_concurrent_${num_task}_${split_name}
|
||||
rm -r $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts --log-dir $log_dir
|
||||
python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name $split_name --log-dir $log_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
@@ -66,5 +68,45 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
audio=../../infer/examples/basic/basic_ref_en.wav
|
||||
reference_text="Some call me nature, others call me mother nature."
|
||||
target_text="I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text"
|
||||
python3 client_http.py --reference-audio $audio --reference-text "$reference_text" --target-text "$target_text" --output-audio "./tests/client_http_$model.wav"
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
echo "TRT-LLM: offline decoding benchmark test"
|
||||
batch_size=1
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=trt
|
||||
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--enable-warmup \
|
||||
--split-name $split_name \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--vocoder-trt-engine-path $VOCODER_TRT_ENGINE_PATH \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
echo "Native Pytorch: offline decoding benchmark test"
|
||||
if ! python3 -c "import f5_tts" &> /dev/null; then
|
||||
pip install -e ../../../../
|
||||
fi
|
||||
batch_size=1 # set attn_mask_enabled=True if batching in actual use case
|
||||
split_name=wenetspeech4tts
|
||||
backend_type=pytorch
|
||||
log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}
|
||||
rm -r $log_dir
|
||||
torchrun --nproc_per_node=1 \
|
||||
benchmark.py --output-dir $log_dir \
|
||||
--batch-size $batch_size \
|
||||
--split-name $split_name \
|
||||
--enable-warmup \
|
||||
--model-path $ckpt_file \
|
||||
--vocab-file $vocab_file \
|
||||
--backend-type $backend_type \
|
||||
--tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
|
||||
fi
|
||||
@@ -40,6 +40,7 @@ import torch as th
|
||||
import torch.nn.functional as F
|
||||
from scipy.signal import check_COLA, get_window
|
||||
|
||||
|
||||
support_clp_op = None
|
||||
if th.__version__ >= "1.7.0":
|
||||
from torch.fft import rfft as fft
|
||||
|
||||
@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import str_dtype_to_torch
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.convert_utils import split, split_matrix_tp
|
||||
@@ -24,168 +23,12 @@ def split_q_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
|
||||
return split_v.contiguous()
|
||||
|
||||
|
||||
FACEBOOK_DIT_NAME_MAPPING = {
|
||||
"^time_embed.time_mlp.0.weight$": "time_embed.mlp1.weight",
|
||||
"^time_embed.time_mlp.0.bias$": "time_embed.mlp1.bias",
|
||||
"^time_embed.time_mlp.2.weight$": "time_embed.mlp2.weight",
|
||||
"^time_embed.time_mlp.2.bias$": "time_embed.mlp2.bias",
|
||||
"^input_embed.conv_pos_embed.conv1d.0.weight$": "input_embed.conv_pos_embed.conv1d1.weight",
|
||||
"^input_embed.conv_pos_embed.conv1d.0.bias$": "input_embed.conv_pos_embed.conv1d1.bias",
|
||||
"^input_embed.conv_pos_embed.conv1d.2.weight$": "input_embed.conv_pos_embed.conv1d2.weight",
|
||||
"^input_embed.conv_pos_embed.conv1d.2.bias$": "input_embed.conv_pos_embed.conv1d2.bias",
|
||||
"^transformer_blocks.0.attn.to_out.0.weight$": "transformer_blocks.0.attn.to_out.weight",
|
||||
"^transformer_blocks.0.attn.to_out.0.bias$": "transformer_blocks.0.attn.to_out.bias",
|
||||
"^transformer_blocks.1.attn.to_out.0.weight$": "transformer_blocks.1.attn.to_out.weight",
|
||||
"^transformer_blocks.1.attn.to_out.0.bias$": "transformer_blocks.1.attn.to_out.bias",
|
||||
"^transformer_blocks.2.attn.to_out.0.weight$": "transformer_blocks.2.attn.to_out.weight",
|
||||
"^transformer_blocks.2.attn.to_out.0.bias$": "transformer_blocks.2.attn.to_out.bias",
|
||||
"^transformer_blocks.3.attn.to_out.0.weight$": "transformer_blocks.3.attn.to_out.weight",
|
||||
"^transformer_blocks.3.attn.to_out.0.bias$": "transformer_blocks.3.attn.to_out.bias",
|
||||
"^transformer_blocks.4.attn.to_out.0.weight$": "transformer_blocks.4.attn.to_out.weight",
|
||||
"^transformer_blocks.4.attn.to_out.0.bias$": "transformer_blocks.4.attn.to_out.bias",
|
||||
"^transformer_blocks.5.attn.to_out.0.weight$": "transformer_blocks.5.attn.to_out.weight",
|
||||
"^transformer_blocks.5.attn.to_out.0.bias$": "transformer_blocks.5.attn.to_out.bias",
|
||||
"^transformer_blocks.6.attn.to_out.0.weight$": "transformer_blocks.6.attn.to_out.weight",
|
||||
"^transformer_blocks.6.attn.to_out.0.bias$": "transformer_blocks.6.attn.to_out.bias",
|
||||
"^transformer_blocks.7.attn.to_out.0.weight$": "transformer_blocks.7.attn.to_out.weight",
|
||||
"^transformer_blocks.7.attn.to_out.0.bias$": "transformer_blocks.7.attn.to_out.bias",
|
||||
"^transformer_blocks.8.attn.to_out.0.weight$": "transformer_blocks.8.attn.to_out.weight",
|
||||
"^transformer_blocks.8.attn.to_out.0.bias$": "transformer_blocks.8.attn.to_out.bias",
|
||||
"^transformer_blocks.9.attn.to_out.0.weight$": "transformer_blocks.9.attn.to_out.weight",
|
||||
"^transformer_blocks.9.attn.to_out.0.bias$": "transformer_blocks.9.attn.to_out.bias",
|
||||
"^transformer_blocks.10.attn.to_out.0.weight$": "transformer_blocks.10.attn.to_out.weight",
|
||||
"^transformer_blocks.10.attn.to_out.0.bias$": "transformer_blocks.10.attn.to_out.bias",
|
||||
"^transformer_blocks.11.attn.to_out.0.weight$": "transformer_blocks.11.attn.to_out.weight",
|
||||
"^transformer_blocks.11.attn.to_out.0.bias$": "transformer_blocks.11.attn.to_out.bias",
|
||||
"^transformer_blocks.12.attn.to_out.0.weight$": "transformer_blocks.12.attn.to_out.weight",
|
||||
"^transformer_blocks.12.attn.to_out.0.bias$": "transformer_blocks.12.attn.to_out.bias",
|
||||
"^transformer_blocks.13.attn.to_out.0.weight$": "transformer_blocks.13.attn.to_out.weight",
|
||||
"^transformer_blocks.13.attn.to_out.0.bias$": "transformer_blocks.13.attn.to_out.bias",
|
||||
"^transformer_blocks.14.attn.to_out.0.weight$": "transformer_blocks.14.attn.to_out.weight",
|
||||
"^transformer_blocks.14.attn.to_out.0.bias$": "transformer_blocks.14.attn.to_out.bias",
|
||||
"^transformer_blocks.15.attn.to_out.0.weight$": "transformer_blocks.15.attn.to_out.weight",
|
||||
"^transformer_blocks.15.attn.to_out.0.bias$": "transformer_blocks.15.attn.to_out.bias",
|
||||
"^transformer_blocks.16.attn.to_out.0.weight$": "transformer_blocks.16.attn.to_out.weight",
|
||||
"^transformer_blocks.16.attn.to_out.0.bias$": "transformer_blocks.16.attn.to_out.bias",
|
||||
"^transformer_blocks.17.attn.to_out.0.weight$": "transformer_blocks.17.attn.to_out.weight",
|
||||
"^transformer_blocks.17.attn.to_out.0.bias$": "transformer_blocks.17.attn.to_out.bias",
|
||||
"^transformer_blocks.18.attn.to_out.0.weight$": "transformer_blocks.18.attn.to_out.weight",
|
||||
"^transformer_blocks.18.attn.to_out.0.bias$": "transformer_blocks.18.attn.to_out.bias",
|
||||
"^transformer_blocks.19.attn.to_out.0.weight$": "transformer_blocks.19.attn.to_out.weight",
|
||||
"^transformer_blocks.19.attn.to_out.0.bias$": "transformer_blocks.19.attn.to_out.bias",
|
||||
"^transformer_blocks.20.attn.to_out.0.weight$": "transformer_blocks.20.attn.to_out.weight",
|
||||
"^transformer_blocks.20.attn.to_out.0.bias$": "transformer_blocks.20.attn.to_out.bias",
|
||||
"^transformer_blocks.21.attn.to_out.0.weight$": "transformer_blocks.21.attn.to_out.weight",
|
||||
"^transformer_blocks.21.attn.to_out.0.bias$": "transformer_blocks.21.attn.to_out.bias",
|
||||
"^transformer_blocks.0.ff.ff.0.0.weight$": "transformer_blocks.0.ff.project_in.weight",
|
||||
"^transformer_blocks.0.ff.ff.0.0.bias$": "transformer_blocks.0.ff.project_in.bias",
|
||||
"^transformer_blocks.0.ff.ff.2.weight$": "transformer_blocks.0.ff.ff.weight",
|
||||
"^transformer_blocks.0.ff.ff.2.bias$": "transformer_blocks.0.ff.ff.bias",
|
||||
"^transformer_blocks.1.ff.ff.0.0.weight$": "transformer_blocks.1.ff.project_in.weight",
|
||||
"^transformer_blocks.1.ff.ff.0.0.bias$": "transformer_blocks.1.ff.project_in.bias",
|
||||
"^transformer_blocks.1.ff.ff.2.weight$": "transformer_blocks.1.ff.ff.weight",
|
||||
"^transformer_blocks.1.ff.ff.2.bias$": "transformer_blocks.1.ff.ff.bias",
|
||||
"^transformer_blocks.2.ff.ff.0.0.weight$": "transformer_blocks.2.ff.project_in.weight",
|
||||
"^transformer_blocks.2.ff.ff.0.0.bias$": "transformer_blocks.2.ff.project_in.bias",
|
||||
"^transformer_blocks.2.ff.ff.2.weight$": "transformer_blocks.2.ff.ff.weight",
|
||||
"^transformer_blocks.2.ff.ff.2.bias$": "transformer_blocks.2.ff.ff.bias",
|
||||
"^transformer_blocks.3.ff.ff.0.0.weight$": "transformer_blocks.3.ff.project_in.weight",
|
||||
"^transformer_blocks.3.ff.ff.0.0.bias$": "transformer_blocks.3.ff.project_in.bias",
|
||||
"^transformer_blocks.3.ff.ff.2.weight$": "transformer_blocks.3.ff.ff.weight",
|
||||
"^transformer_blocks.3.ff.ff.2.bias$": "transformer_blocks.3.ff.ff.bias",
|
||||
"^transformer_blocks.4.ff.ff.0.0.weight$": "transformer_blocks.4.ff.project_in.weight",
|
||||
"^transformer_blocks.4.ff.ff.0.0.bias$": "transformer_blocks.4.ff.project_in.bias",
|
||||
"^transformer_blocks.4.ff.ff.2.weight$": "transformer_blocks.4.ff.ff.weight",
|
||||
"^transformer_blocks.4.ff.ff.2.bias$": "transformer_blocks.4.ff.ff.bias",
|
||||
"^transformer_blocks.5.ff.ff.0.0.weight$": "transformer_blocks.5.ff.project_in.weight",
|
||||
"^transformer_blocks.5.ff.ff.0.0.bias$": "transformer_blocks.5.ff.project_in.bias",
|
||||
"^transformer_blocks.5.ff.ff.2.weight$": "transformer_blocks.5.ff.ff.weight",
|
||||
"^transformer_blocks.5.ff.ff.2.bias$": "transformer_blocks.5.ff.ff.bias",
|
||||
"^transformer_blocks.6.ff.ff.0.0.weight$": "transformer_blocks.6.ff.project_in.weight",
|
||||
"^transformer_blocks.6.ff.ff.0.0.bias$": "transformer_blocks.6.ff.project_in.bias",
|
||||
"^transformer_blocks.6.ff.ff.2.weight$": "transformer_blocks.6.ff.ff.weight",
|
||||
"^transformer_blocks.6.ff.ff.2.bias$": "transformer_blocks.6.ff.ff.bias",
|
||||
"^transformer_blocks.7.ff.ff.0.0.weight$": "transformer_blocks.7.ff.project_in.weight",
|
||||
"^transformer_blocks.7.ff.ff.0.0.bias$": "transformer_blocks.7.ff.project_in.bias",
|
||||
"^transformer_blocks.7.ff.ff.2.weight$": "transformer_blocks.7.ff.ff.weight",
|
||||
"^transformer_blocks.7.ff.ff.2.bias$": "transformer_blocks.7.ff.ff.bias",
|
||||
"^transformer_blocks.8.ff.ff.0.0.weight$": "transformer_blocks.8.ff.project_in.weight",
|
||||
"^transformer_blocks.8.ff.ff.0.0.bias$": "transformer_blocks.8.ff.project_in.bias",
|
||||
"^transformer_blocks.8.ff.ff.2.weight$": "transformer_blocks.8.ff.ff.weight",
|
||||
"^transformer_blocks.8.ff.ff.2.bias$": "transformer_blocks.8.ff.ff.bias",
|
||||
"^transformer_blocks.9.ff.ff.0.0.weight$": "transformer_blocks.9.ff.project_in.weight",
|
||||
"^transformer_blocks.9.ff.ff.0.0.bias$": "transformer_blocks.9.ff.project_in.bias",
|
||||
"^transformer_blocks.9.ff.ff.2.weight$": "transformer_blocks.9.ff.ff.weight",
|
||||
"^transformer_blocks.9.ff.ff.2.bias$": "transformer_blocks.9.ff.ff.bias",
|
||||
"^transformer_blocks.10.ff.ff.0.0.weight$": "transformer_blocks.10.ff.project_in.weight",
|
||||
"^transformer_blocks.10.ff.ff.0.0.bias$": "transformer_blocks.10.ff.project_in.bias",
|
||||
"^transformer_blocks.10.ff.ff.2.weight$": "transformer_blocks.10.ff.ff.weight",
|
||||
"^transformer_blocks.10.ff.ff.2.bias$": "transformer_blocks.10.ff.ff.bias",
|
||||
"^transformer_blocks.11.ff.ff.0.0.weight$": "transformer_blocks.11.ff.project_in.weight",
|
||||
"^transformer_blocks.11.ff.ff.0.0.bias$": "transformer_blocks.11.ff.project_in.bias",
|
||||
"^transformer_blocks.11.ff.ff.2.weight$": "transformer_blocks.11.ff.ff.weight",
|
||||
"^transformer_blocks.11.ff.ff.2.bias$": "transformer_blocks.11.ff.ff.bias",
|
||||
"^transformer_blocks.12.ff.ff.0.0.weight$": "transformer_blocks.12.ff.project_in.weight",
|
||||
"^transformer_blocks.12.ff.ff.0.0.bias$": "transformer_blocks.12.ff.project_in.bias",
|
||||
"^transformer_blocks.12.ff.ff.2.weight$": "transformer_blocks.12.ff.ff.weight",
|
||||
"^transformer_blocks.12.ff.ff.2.bias$": "transformer_blocks.12.ff.ff.bias",
|
||||
"^transformer_blocks.13.ff.ff.0.0.weight$": "transformer_blocks.13.ff.project_in.weight",
|
||||
"^transformer_blocks.13.ff.ff.0.0.bias$": "transformer_blocks.13.ff.project_in.bias",
|
||||
"^transformer_blocks.13.ff.ff.2.weight$": "transformer_blocks.13.ff.ff.weight",
|
||||
"^transformer_blocks.13.ff.ff.2.bias$": "transformer_blocks.13.ff.ff.bias",
|
||||
"^transformer_blocks.14.ff.ff.0.0.weight$": "transformer_blocks.14.ff.project_in.weight",
|
||||
"^transformer_blocks.14.ff.ff.0.0.bias$": "transformer_blocks.14.ff.project_in.bias",
|
||||
"^transformer_blocks.14.ff.ff.2.weight$": "transformer_blocks.14.ff.ff.weight",
|
||||
"^transformer_blocks.14.ff.ff.2.bias$": "transformer_blocks.14.ff.ff.bias",
|
||||
"^transformer_blocks.15.ff.ff.0.0.weight$": "transformer_blocks.15.ff.project_in.weight",
|
||||
"^transformer_blocks.15.ff.ff.0.0.bias$": "transformer_blocks.15.ff.project_in.bias",
|
||||
"^transformer_blocks.15.ff.ff.2.weight$": "transformer_blocks.15.ff.ff.weight",
|
||||
"^transformer_blocks.15.ff.ff.2.bias$": "transformer_blocks.15.ff.ff.bias",
|
||||
"^transformer_blocks.16.ff.ff.0.0.weight$": "transformer_blocks.16.ff.project_in.weight",
|
||||
"^transformer_blocks.16.ff.ff.0.0.bias$": "transformer_blocks.16.ff.project_in.bias",
|
||||
"^transformer_blocks.16.ff.ff.2.weight$": "transformer_blocks.16.ff.ff.weight",
|
||||
"^transformer_blocks.16.ff.ff.2.bias$": "transformer_blocks.16.ff.ff.bias",
|
||||
"^transformer_blocks.17.ff.ff.0.0.weight$": "transformer_blocks.17.ff.project_in.weight",
|
||||
"^transformer_blocks.17.ff.ff.0.0.bias$": "transformer_blocks.17.ff.project_in.bias",
|
||||
"^transformer_blocks.17.ff.ff.2.weight$": "transformer_blocks.17.ff.ff.weight",
|
||||
"^transformer_blocks.17.ff.ff.2.bias$": "transformer_blocks.17.ff.ff.bias",
|
||||
"^transformer_blocks.18.ff.ff.0.0.weight$": "transformer_blocks.18.ff.project_in.weight",
|
||||
"^transformer_blocks.18.ff.ff.0.0.bias$": "transformer_blocks.18.ff.project_in.bias",
|
||||
"^transformer_blocks.18.ff.ff.2.weight$": "transformer_blocks.18.ff.ff.weight",
|
||||
"^transformer_blocks.18.ff.ff.2.bias$": "transformer_blocks.18.ff.ff.bias",
|
||||
"^transformer_blocks.19.ff.ff.0.0.weight$": "transformer_blocks.19.ff.project_in.weight",
|
||||
"^transformer_blocks.19.ff.ff.0.0.bias$": "transformer_blocks.19.ff.project_in.bias",
|
||||
"^transformer_blocks.19.ff.ff.2.weight$": "transformer_blocks.19.ff.ff.weight",
|
||||
"^transformer_blocks.19.ff.ff.2.bias$": "transformer_blocks.19.ff.ff.bias",
|
||||
"^transformer_blocks.20.ff.ff.0.0.weight$": "transformer_blocks.20.ff.project_in.weight",
|
||||
"^transformer_blocks.20.ff.ff.0.0.bias$": "transformer_blocks.20.ff.project_in.bias",
|
||||
"^transformer_blocks.20.ff.ff.2.weight$": "transformer_blocks.20.ff.ff.weight",
|
||||
"^transformer_blocks.20.ff.ff.2.bias$": "transformer_blocks.20.ff.ff.bias",
|
||||
"^transformer_blocks.21.ff.ff.0.0.weight$": "transformer_blocks.21.ff.project_in.weight",
|
||||
"^transformer_blocks.21.ff.ff.0.0.bias$": "transformer_blocks.21.ff.project_in.bias",
|
||||
"^transformer_blocks.21.ff.ff.2.weight$": "transformer_blocks.21.ff.ff.weight",
|
||||
"^transformer_blocks.21.ff.ff.2.bias$": "transformer_blocks.21.ff.ff.bias",
|
||||
}
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="F5TTS_Base",
|
||||
choices=[
|
||||
"F5TTS_Base",
|
||||
],
|
||||
) # TODO: support F5TTS_v1_Base
|
||||
parser.add_argument("--timm_ckpt", type=str, default="./ckpts/model_1200000.pt")
|
||||
parser.add_argument("--pytorch_ckpt", type=str, default="./ckpts/model_last.pt")
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="./tllm_checkpoint", help="The path to save the TensorRT-LLM checkpoint"
|
||||
)
|
||||
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
||||
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
||||
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
||||
parser.add_argument("--cfg_scale", type=float, default=4.0)
|
||||
parser.add_argument("--tp_size", type=int, default=1, help="N-way tensor parallelism size")
|
||||
parser.add_argument("--cp_size", type=int, default=1, help="Context parallelism size")
|
||||
parser.add_argument("--pp_size", type=int, default=1, help="N-way pipeline parallelism size")
|
||||
@@ -194,33 +37,119 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=1, help="The number of workers for converting checkpoint in parallel"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="F5TTS_Custom",
|
||||
choices=[
|
||||
"F5TTS_v1_Base",
|
||||
"F5TTS_Base",
|
||||
"F5TTS_v1_Small",
|
||||
"F5TTS_Small",
|
||||
], # if set, overwrite the below hyperparams
|
||||
)
|
||||
parser.add_argument("--hidden_size", type=int, default=1024, help="The hidden size of DiT")
|
||||
parser.add_argument("--depth", type=int, default=22, help="The number of DiTBlock layers")
|
||||
parser.add_argument("--num_heads", type=int, default=16, help="The number of heads of attention module")
|
||||
parser.add_argument("--dim_head", type=int, default=64, help="The dimension of attention head")
|
||||
parser.add_argument("--ff_mult", type=int, default=2, help="The FFN intermediate dimension multiplier")
|
||||
parser.add_argument("--text_dim", type=int, default=512, help="The output dimension of text encoder")
|
||||
parser.add_argument(
|
||||
"--text_mask_padding",
|
||||
type=lambda x: x.lower() == "true",
|
||||
choices=[True, False],
|
||||
default=True,
|
||||
help="Whether apply padding mask for conv layers in text encoder",
|
||||
)
|
||||
parser.add_argument("--conv_layers", type=int, default=4, help="The number of conv layers of text encoder")
|
||||
parser.add_argument("--pe_attn_head", type=int, default=None, help="The number of attn head that apply pos emb")
|
||||
args = parser.parse_args()
|
||||
|
||||
# overwrite if --model_name ordered
|
||||
if args.model_name == "F5TTS_v1_Base":
|
||||
args.hidden_size = 1024
|
||||
args.depth = 22
|
||||
args.num_heads = 16
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = True
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = None
|
||||
elif args.model_name == "F5TTS_Base":
|
||||
args.hidden_size = 1024
|
||||
args.depth = 22
|
||||
args.num_heads = 16
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = False
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = 1
|
||||
elif args.model_name == "F5TTS_v1_Small":
|
||||
args.hidden_size = 768
|
||||
args.depth = 18
|
||||
args.num_heads = 12
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = True
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = None
|
||||
elif args.model_name == "F5TTS_Small":
|
||||
args.hidden_size = 768
|
||||
args.depth = 18
|
||||
args.num_heads = 12
|
||||
args.dim_head = 64
|
||||
args.ff_mult = 2
|
||||
args.text_dim = 512
|
||||
args.text_mask_padding = False
|
||||
args.conv_layers = 4
|
||||
args.pe_attn_head = 1
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def convert_timm_dit(args, mapping, dtype="float32"):
|
||||
def convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype="float32", use_ema=True):
|
||||
weights = {}
|
||||
tik = time.time()
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
tensor_parallel = mapping.tp_size
|
||||
|
||||
model_params = dict(torch.load(args.timm_ckpt))
|
||||
model_params = {
|
||||
k: v for k, v in model_params["ema_model_state_dict"].items() if k.startswith("ema_model.transformer")
|
||||
ckpt_path = args.pytorch_ckpt
|
||||
ckpt_type = ckpt_path.split(".")[-1]
|
||||
if ckpt_type == "safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
model_params = load_file(ckpt_path)
|
||||
else:
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model_params = ckpt["ema_model_state_dict"] if use_ema else ckpt["model_state_dict"]
|
||||
|
||||
prefix = "ema_model.transformer." if use_ema else "transformer."
|
||||
if any(k.startswith(prefix) for k in model_params.keys()):
|
||||
model_params = {
|
||||
key[len(prefix) :] if key.startswith(prefix) else key: value
|
||||
for key, value in model_params.items()
|
||||
if key.startswith(prefix)
|
||||
}
|
||||
|
||||
pytorch_to_trtllm_name = {
|
||||
r"^time_embed\.time_mlp\.0\.(weight|bias)$": r"time_embed.mlp1.\1",
|
||||
r"^time_embed\.time_mlp\.2\.(weight|bias)$": r"time_embed.mlp2.\1",
|
||||
r"^input_embed\.conv_pos_embed\.conv1d\.0\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d1.\1",
|
||||
r"^input_embed\.conv_pos_embed\.conv1d\.2\.(weight|bias)$": r"input_embed.conv_pos_embed.conv1d2.\1",
|
||||
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(weight|bias)$": r"transformer_blocks.\1.attn.to_out.\2",
|
||||
r"^transformer_blocks\.(\d+)\.ff\.ff\.0\.0\.(weight|bias)$": r"transformer_blocks.\1.ff.project_in.\2",
|
||||
r"^transformer_blocks\.(\d+)\.ff\.ff\.2\.(weight|bias)$": r"transformer_blocks.\1.ff.ff.\2",
|
||||
}
|
||||
prefix = "ema_model.transformer."
|
||||
model_params = {key[len(prefix) :] if key.startswith(prefix) else key: value for key, value in model_params.items()}
|
||||
|
||||
timm_to_trtllm_name = FACEBOOK_DIT_NAME_MAPPING
|
||||
|
||||
def get_trtllm_name(timm_name):
|
||||
for k, v in timm_to_trtllm_name.items():
|
||||
m = re.match(k, timm_name)
|
||||
if m is not None:
|
||||
if "*" in v:
|
||||
v = v.replace("*", m.groups()[0])
|
||||
return v
|
||||
return timm_name
|
||||
def get_trtllm_name(pytorch_name):
|
||||
for pytorch_name_pattern, trtllm_name_replacement in pytorch_to_trtllm_name.items():
|
||||
trtllm_name_if_matched = re.sub(pytorch_name_pattern, trtllm_name_replacement, pytorch_name)
|
||||
if trtllm_name_if_matched != pytorch_name:
|
||||
return trtllm_name_if_matched
|
||||
return pytorch_name
|
||||
|
||||
weights = dict()
|
||||
for name, param in model_params.items():
|
||||
@@ -231,7 +160,7 @@ def convert_timm_dit(args, mapping, dtype="float32"):
|
||||
|
||||
assert len(weights) == len(model_params)
|
||||
|
||||
# new_prefix = 'f5_transformer.'
|
||||
# new_prefix = "f5_transformer."
|
||||
new_prefix = ""
|
||||
weights = {new_prefix + key: value for key, value in weights.items()}
|
||||
import math
|
||||
@@ -273,19 +202,19 @@ def save_config(args):
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
config = {
|
||||
"architecture": "F5TTS",
|
||||
"architecture": "F5TTS", # set the same as in ../patch/__init__.py
|
||||
"dtype": args.dtype,
|
||||
"hidden_size": 1024,
|
||||
"num_hidden_layers": 22,
|
||||
"num_attention_heads": 16,
|
||||
"dim_head": 64,
|
||||
"dropout": 0.1,
|
||||
"ff_mult": 2,
|
||||
"hidden_size": args.hidden_size,
|
||||
"num_hidden_layers": args.depth,
|
||||
"num_attention_heads": args.num_heads,
|
||||
"dim_head": args.dim_head,
|
||||
"dropout": 0.0, # inference-only
|
||||
"ff_mult": args.ff_mult,
|
||||
"mel_dim": 100,
|
||||
"text_num_embeds": 256,
|
||||
"text_dim": 512,
|
||||
"conv_layers": 4,
|
||||
"long_skip_connection": False,
|
||||
"text_dim": args.text_dim,
|
||||
"text_mask_padding": args.text_mask_padding,
|
||||
"conv_layers": args.conv_layers,
|
||||
"pe_attn_head": args.pe_attn_head,
|
||||
"mapping": {
|
||||
"world_size": args.cp_size * args.tp_size * args.pp_size,
|
||||
"cp_size": args.cp_size,
|
||||
@@ -297,7 +226,7 @@ def save_config(args):
|
||||
config["quantization"] = {
|
||||
"quant_algo": "FP8",
|
||||
# TODO: add support for exclude modules.
|
||||
# 'exclude_modules': "*final_layer*",
|
||||
# "exclude_modules": "*final_layer*",
|
||||
}
|
||||
|
||||
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
|
||||
@@ -316,7 +245,7 @@ def covert_and_save(args, rank):
|
||||
pp_size=args.pp_size,
|
||||
)
|
||||
|
||||
weights = convert_timm_dit(args, mapping, dtype=args.dtype)
|
||||
weights = convert_pytorch_dit_to_trtllm_weight(args, mapping, dtype=args.dtype)
|
||||
|
||||
safetensors.torch.save_file(weights, os.path.join(args.output_dir, f"rank{rank}.safetensors"))
|
||||
|
||||
@@ -345,9 +274,9 @@ def main():
|
||||
assert args.pp_size == 1, "PP is not supported yet."
|
||||
|
||||
tik = time.time()
|
||||
if args.timm_ckpt is None:
|
||||
if args.pytorch_ckpt is None:
|
||||
return
|
||||
print("start execute")
|
||||
print("Start execute")
|
||||
execute(args.workers, [covert_and_save] * world_size, args)
|
||||
|
||||
tok = time.time()
|
||||
|
||||
@@ -12,13 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from conv_stft import STFT
|
||||
from huggingface_hub import hf_hub_download
|
||||
from vocos import Vocos
|
||||
import argparse
|
||||
|
||||
|
||||
opset_version = 17
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Manual installation of TensorRT, in case not using NVIDIA NGC:
|
||||
# https://docs.nvidia.com/deeplearning/tensorrt/latest/installing-tensorrt/installing.html#downloading-tensorrt
|
||||
TRTEXEC="/usr/src/tensorrt/bin/trtexec"
|
||||
|
||||
ONNX_PATH=$1
|
||||
@@ -28,7 +30,7 @@ MAX_BATCH_SIZE=8
|
||||
|
||||
MIN_INPUT_LENGTH=1
|
||||
OPT_INPUT_LENGTH=1000
|
||||
MAX_INPUT_LENGTH=3000
|
||||
MAX_INPUT_LENGTH=3000 # 4096
|
||||
|
||||
MEL_MIN_SHAPE="${MIN_BATCH_SIZE}x100x${MIN_INPUT_LENGTH}"
|
||||
MEL_OPT_SHAPE="${OPT_BATCH_SIZE}x100x${OPT_INPUT_LENGTH}"
|
||||
@@ -40,4 +42,3 @@ ${TRTEXEC} \
|
||||
--maxShapes="mel:${MEL_MAX_SHAPE}" \
|
||||
--onnx=${ONNX_PATH} \
|
||||
--saveEngine=${ENGINE_PATH}
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from f5_tts.model import CFM, DiT
|
||||
|
||||
import torch
|
||||
import thop
|
||||
import torch
|
||||
|
||||
from f5_tts.model import CFM, DiT
|
||||
|
||||
|
||||
""" ~155M """
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import socket
|
||||
import asyncio
|
||||
import pyaudio
|
||||
import numpy as np
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
import gc
|
||||
import logging
|
||||
import numpy as np
|
||||
import queue
|
||||
import socket
|
||||
import struct
|
||||
@@ -10,6 +9,7 @@ import traceback
|
||||
import wave
|
||||
from importlib.resources import files
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -18,12 +18,13 @@ from omegaconf import OmegaConf
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
chunk_text,
|
||||
preprocess_ref_audio_text,
|
||||
load_vocoder,
|
||||
load_model,
|
||||
infer_batch_process,
|
||||
load_model,
|
||||
load_vocoder,
|
||||
preprocess_ref_audio_text,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# Training
|
||||
|
||||
Check your FFmpeg installation:
|
||||
```bash
|
||||
ffmpeg -version
|
||||
```
|
||||
If not found, install it first (or skip assuming you know of other backends available).
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
Example data processing scripts, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`.
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import subprocess # For invoking ffprobe
|
||||
import shutil
|
||||
import concurrent.futures
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess # For invoking ffprobe
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import argparse
|
||||
@@ -16,12 +17,10 @@ from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import (
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
|
||||
@@ -209,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
||||
out_dir.mkdir(exist_ok=True, parents=True)
|
||||
print(f"\nSaving to {out_dir} ...")
|
||||
|
||||
# Save dataset with improved batch size for better I/O performance
|
||||
raw_arrow_path = out_dir / "raw.arrow"
|
||||
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
|
||||
with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# Save durations to JSON
|
||||
dur_json_path = out_dir / "duration.json"
|
||||
|
||||
@@ -7,20 +7,18 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import (
|
||||
repetition_found,
|
||||
convert_char_to_pinyin,
|
||||
)
|
||||
from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
|
||||
|
||||
|
||||
out_zh = {
|
||||
@@ -183,6 +181,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
|
||||
95
src/f5_tts/train/datasets/prepare_emilia_v2.py
Normal file
95
src/f5_tts/train/datasets/prepare_emilia_v2.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# put in src/f5_tts/train/datasets/prepare_emilia_v2.py
|
||||
# prepares Emilia dataset with the new format w/ Emilia-YODAS
|
||||
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import repetition_found
|
||||
|
||||
|
||||
# Define filters for exclusion
|
||||
out_en = set()
|
||||
en_filters = ["ا", "い", "て"]
|
||||
|
||||
|
||||
def process_audio_directory(audio_dir):
|
||||
sub_result, durations, vocab_set = [], [], set()
|
||||
bad_case_en = 0
|
||||
|
||||
for file in audio_dir.iterdir():
|
||||
if file.suffix == ".json":
|
||||
with open(file, "r") as f:
|
||||
obj = json.load(f)
|
||||
text = obj["text"]
|
||||
if any(f in text for f in en_filters) or repetition_found(text, length=4):
|
||||
bad_case_en += 1
|
||||
continue
|
||||
|
||||
duration = obj["duration"]
|
||||
audio_file = file.with_suffix(".mp3")
|
||||
if audio_file.exists():
|
||||
sub_result.append({"audio_path": str(audio_file), "text": text, "duration": duration})
|
||||
durations.append(duration)
|
||||
vocab_set.update(list(text))
|
||||
|
||||
return sub_result, durations, vocab_set, bad_case_en
|
||||
|
||||
|
||||
def main():
|
||||
assert tokenizer in ["pinyin", "char"]
|
||||
result, duration_list, text_vocab_set = [], [], set()
|
||||
total_bad_case_en = 0
|
||||
|
||||
executor = ProcessPoolExecutor(max_workers=max_workers)
|
||||
futures = []
|
||||
dataset_path = Path(dataset_dir)
|
||||
for sub_dir in dataset_path.iterdir():
|
||||
if sub_dir.is_dir():
|
||||
futures.append(executor.submit(process_audio_directory, sub_dir))
|
||||
|
||||
for future in tqdm(futures, total=len(futures)):
|
||||
sub_result, durations, vocab_set, bad_case_en = future.result()
|
||||
result.extend(sub_result)
|
||||
duration_list.extend(durations)
|
||||
text_vocab_set.update(vocab_set)
|
||||
total_bad_case_en += bad_case_en
|
||||
|
||||
executor.shutdown()
|
||||
|
||||
if not os.path.exists(f"{save_dir}"):
|
||||
os.makedirs(f"{save_dir}")
|
||||
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||
|
||||
with open(f"{save_dir}/vocab.txt", "w") as f:
|
||||
for vocab in sorted(text_vocab_set):
|
||||
f.write(vocab + "\n")
|
||||
|
||||
print(f"For {dataset_name}, sample count: {len(result)}")
|
||||
print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
|
||||
print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
||||
print(f"Bad en transcription case: {total_bad_case_en}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
max_workers = 32
|
||||
tokenizer = "char"
|
||||
dataset_dir = "/home/ubuntu/emilia-dataset/Emilia-YODAS/EN"
|
||||
dataset_name = f"Emilia_EN_{tokenizer}"
|
||||
# save_dir = os.path.expanduser(f"~/F5-TTS/data/{dataset_name}")
|
||||
save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
|
||||
|
||||
print(f"Prepare for {dataset_name}, will save to {save_dir}\n")
|
||||
main()
|
||||
@@ -1,15 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
import soundfile as sf
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def deal_with_audio_dir(audio_dir):
|
||||
@@ -60,6 +62,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
import soundfile as sf
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
@@ -37,6 +39,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -4,15 +4,16 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
import json
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from importlib.resources import files
|
||||
from tqdm import tqdm
|
||||
|
||||
import torchaudio
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from importlib.resources import files
|
||||
|
||||
from cached_path import cached_path
|
||||
|
||||
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
from f5_tts.model import CFM, DiT, Trainer, UNetT
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
# -------------------------- Dataset Settings --------------------------- #
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import gc
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import platform
|
||||
import psutil
|
||||
import queue
|
||||
import random
|
||||
import re
|
||||
import signal
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -16,21 +14,23 @@ import threading
|
||||
import time
|
||||
from glob import glob
|
||||
from importlib.resources import files
|
||||
from scipy.io import wavfile
|
||||
|
||||
import click
|
||||
import gradio as gr
|
||||
import librosa
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import torchaudio
|
||||
from cached_path import cached_path
|
||||
from datasets import Dataset as Dataset_
|
||||
from datasets.arrow_writer import ArrowWriter
|
||||
from safetensors.torch import load_file, save_file
|
||||
from scipy.io import wavfile
|
||||
|
||||
from f5_tts.api import F5TTS
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
from f5_tts.infer.utils_infer import transcribe
|
||||
from f5_tts.model.utils import convert_char_to_pinyin
|
||||
|
||||
|
||||
training_process = None
|
||||
@@ -138,6 +138,8 @@ def load_settings(project_name):
|
||||
"logger": "none",
|
||||
"bnb_optimizer": False,
|
||||
}
|
||||
if device == "mps":
|
||||
default_settings["mixed_precision"] = "none"
|
||||
|
||||
# Load settings from file if it exists
|
||||
if os.path.isfile(file_setting):
|
||||
@@ -176,50 +178,12 @@ def get_audio_duration(audio_path):
|
||||
return audio.shape[1] / sample_rate
|
||||
|
||||
|
||||
def clear_text(text):
|
||||
"""Clean and prepare text by lowering the case and stripping whitespace."""
|
||||
return text.lower().strip()
|
||||
|
||||
|
||||
def get_rms(
|
||||
y,
|
||||
frame_length=2048,
|
||||
hop_length=512,
|
||||
pad_mode="constant",
|
||||
): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
|
||||
padding = (int(frame_length // 2), int(frame_length // 2))
|
||||
y = np.pad(y, padding, mode=pad_mode)
|
||||
|
||||
axis = -1
|
||||
# put our new within-frame axis at the end for now
|
||||
out_strides = y.strides + tuple([y.strides[axis]])
|
||||
# Reduce the shape on the framing axis
|
||||
x_shape_trimmed = list(y.shape)
|
||||
x_shape_trimmed[axis] -= frame_length - 1
|
||||
out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
|
||||
xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
|
||||
if axis < 0:
|
||||
target_axis = axis - 1
|
||||
else:
|
||||
target_axis = axis + 1
|
||||
xw = np.moveaxis(xw, -1, target_axis)
|
||||
# Downsample along the target axis
|
||||
slices = [slice(None)] * xw.ndim
|
||||
slices[axis] = slice(0, None, hop_length)
|
||||
x = xw[tuple(slices)]
|
||||
|
||||
# Calculate power
|
||||
power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
|
||||
|
||||
return np.sqrt(power)
|
||||
|
||||
|
||||
class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
|
||||
def __init__(
|
||||
self,
|
||||
sr: int,
|
||||
threshold: float = -40.0,
|
||||
min_length: int = 2000,
|
||||
min_length: int = 20000, # 20 seconds
|
||||
min_interval: int = 300,
|
||||
hop_size: int = 20,
|
||||
max_sil_kept: int = 2000,
|
||||
@@ -250,7 +214,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
|
||||
samples = waveform
|
||||
if samples.shape[0] <= self.min_length:
|
||||
return [waveform]
|
||||
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
||||
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
||||
sil_tags = []
|
||||
silence_start = None
|
||||
clip_start = 0
|
||||
@@ -304,8 +268,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
|
||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
||||
sil_tags.append((pos, total_frames + 1))
|
||||
# Apply and return slices.
|
||||
####音频+起始时间+终止时间
|
||||
# Apply and return slices: [chunk, start, end]
|
||||
if len(sil_tags) == 0:
|
||||
return [[waveform, 0, int(total_frames * self.hop_size)]]
|
||||
else:
|
||||
@@ -432,7 +395,7 @@ def start_training(
|
||||
fp16 = ""
|
||||
|
||||
cmd = (
|
||||
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
|
||||
f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
|
||||
f" --learning_rate {learning_rate}"
|
||||
f" --batch_size_per_gpu {batch_size_per_gpu}"
|
||||
f" --batch_size_type {batch_size_type}"
|
||||
@@ -451,7 +414,7 @@ def start_training(
|
||||
cmd += " --finetune"
|
||||
|
||||
if file_checkpoint_train != "":
|
||||
cmd += f" --pretrain {file_checkpoint_train}"
|
||||
cmd += f' --pretrain "{file_checkpoint_train}"'
|
||||
|
||||
if tokenizer_file != "":
|
||||
cmd += f" --tokenizer_path {tokenizer_file}"
|
||||
@@ -705,7 +668,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
|
||||
|
||||
try:
|
||||
text = transcribe(file_segment, language)
|
||||
text = text.lower().strip().replace('"', "")
|
||||
text = text.strip()
|
||||
|
||||
data += f"{name_segment}|{text}\n"
|
||||
|
||||
@@ -814,7 +777,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
||||
error_files.append([file_audio, "very short text length 3"])
|
||||
continue
|
||||
|
||||
text = clear_text(text)
|
||||
text = text.strip()
|
||||
text = convert_char_to_pinyin([text], polyphone=True)[0]
|
||||
|
||||
audio_path_list.append(file_audio)
|
||||
@@ -833,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
||||
min_second = round(min(duration_list), 2)
|
||||
max_second = round(max(duration_list), 2)
|
||||
|
||||
with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
|
||||
with ArrowWriter(path=file_raw) as writer:
|
||||
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
with open(file_duration, "w") as f:
|
||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||
@@ -1097,7 +1061,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
|
||||
|
||||
|
||||
def vocab_check(project_name):
|
||||
def vocab_check(project_name, tokenizer_type):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
|
||||
@@ -1125,7 +1089,9 @@ def vocab_check(project_name):
|
||||
if len(sp) != 2:
|
||||
continue
|
||||
|
||||
text = sp[1].lower().strip()
|
||||
text = sp[1].strip()
|
||||
if tokenizer_type == "pinyin":
|
||||
text = convert_char_to_pinyin([text], polyphone=True)[0]
|
||||
|
||||
for t in text:
|
||||
if t not in vocab and t not in miss_symbols_keep:
|
||||
@@ -1230,8 +1196,8 @@ def infer(
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
tts_api.infer(
|
||||
ref_file=ref_audio,
|
||||
ref_text=ref_text.lower().strip(),
|
||||
gen_text=gen_text.lower().strip(),
|
||||
ref_text=ref_text.strip(),
|
||||
gen_text=gen_text.strip(),
|
||||
nfe_step=nfe_step,
|
||||
speed=speed,
|
||||
remove_silence=remove_silence,
|
||||
@@ -1496,7 +1462,9 @@ Using the extended model, you can finetune to a new language that is missing sym
|
||||
txt_info_extend = gr.Textbox(label="Info", value="")
|
||||
|
||||
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
|
||||
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
|
||||
check_button.click(
|
||||
fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
|
||||
)
|
||||
extend_button.click(
|
||||
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from f5_tts.model import CFM, Trainer
|
||||
from f5_tts.model.dataset import load_dataset
|
||||
from f5_tts.model.utils import get_tokenizer
|
||||
|
||||
|
||||
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user