47 Commits
1.1.3 ... 1.1.9

Author SHA1 Message Date
SWivid
77d3ec623b v1.1.9 2025-09-13 13:42:33 +08:00
SWivid
186799d6dc remove numpy<=1.26.4 for python_version>=3.11 #1162; update links 2025-09-13 13:40:55 +08:00
Yushen CHEN
31bb78f2ab Update badge links 2025-09-03 15:12:24 +08:00
SWivid
e61824009a v1.1.8 2025-08-28 12:33:37 +00:00
SWivid
06a74910bd add option for text embedding late average upsampling 2025-08-28 11:46:11 +00:00
Yushen CHEN
ac3c43595c delete .github/workflows/sync-hf.yaml for online space stablility 2025-08-27 06:52:18 +08:00
Jim
605fa13b42 Fix raw.arrow missing rows (#1145)
* fix raw.arrow missing rows

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-07-22 19:38:44 +08:00
Yushen CHEN
5f35f27230 update pyproject.toml 2025-07-15 17:28:41 +08:00
Yushen CHEN
c96c3aeed8 Update pyproject.toml 2025-07-14 14:36:26 +08:00
Yushen CHEN
9b60fe6a34 update pyproject.toml, set gradio<=5.35.0 until fix #1126 2025-07-14 14:29:19 +08:00
SWivid
a275798a2f last fix patch-1 2025-07-08 18:44:47 +08:00
SWivid
efc7a7498b fix #1111 #1037 remove redundant unwrap_model for AcceleratedOptimizer; which has no attribute '_modules' thus conflict with has_compiled_regions check introduced in accelerate v1.7.0 2025-07-08 18:39:43 +08:00
SWivid
9842314127 update slicer in finetune_gradio, legacy min_length 2s changed to 20s 2025-07-08 16:59:46 +08:00
SWivid
69b0e0110e v1.1.6 fla support, several changed for finetune and infer-cli 2025-07-03 00:08:42 +08:00
SWivid
52c84776e5 fine-grained speed control for infer-cli. #1112 2025-07-02 23:41:55 +08:00
Danh Tran
ebbd7bd91f Update WAV File Naming and Dependencies 📝🔊 (#1091)
* Update infer_cli.py

* Update pyproject.toml

* formalized

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-24 23:23:00 +08:00
Yushen CHEN
ac42286d04 update finetune_gradio.py, not to force lower case
Not to force lower case, otherwise train infer mismatch with main infer code
2025-06-23 16:37:51 +08:00
Yushen CHEN
d937efa6f3 fix finetune_gradio.py, not to force lower case 2025-06-23 16:22:33 +08:00
Yushen CHEN
8975fca803 Merge pull request #1084 from starkwj/main
Speedup inference by batching CFG in DiT
2025-06-12 03:54:04 +08:00
SWivid
8b0053ad0c backward compatibility 2025-06-12 03:52:12 +08:00
SWivid
b3ef4ed1d7 correct imple., minor fixes 2025-06-12 03:32:19 +08:00
starkwj
b1a9438496 Batch cfg DiT forward 2025-06-11 09:03:30 +00:00
Zhikang Niu
0914170e98 Add flash_attn2 support attn_mask, minor fixes (#1066)
* add flash attn2 support
* update flash attn config in F5TTS
* fix minor bug of get the length of ref_mel

---------

Co-authored-by: SWivid <swivid@qq.com>
2025-06-11 12:14:32 +08:00
SWivid
c6ebad0220 switch sync-hf workflow logic on release, avoid hidden space error with pypi/local_editable mismatch 2025-06-06 07:23:54 +08:00
SWivid
cfaba6387f refresh hf-space first 2025-06-06 07:22:02 +08:00
SWivid
646f34b20f v1.1.5 pypi 2025-06-06 07:08:59 +08:00
Jerrister Zheng
2e2acc6ea2 Update: Empirically Pruned Step Sampling (#1077)
* update Empirically Pruned Step Sampling

---------

Co-authored-by: Fast-F5-TTS <2942755472@qq.com>
Co-authored-by: SWivid <swivid@qq.com>
2025-06-04 22:59:30 +08:00
SWivid
6fbe7592f5 rebase default sample_rate to 24khz for runtime 2025-06-04 11:22:31 +08:00
Alice Yanagi
7e37bc5d9a Fix the duration computation in triton_trtllm/client_grpc.py (#1071)
* Update client_grpc.py

Using `actual_duration` to compute metrics like RTF.
2025-06-04 11:18:00 +08:00
SWivid
35f130ee85 minor update for infer-gradio 2025-06-04 06:11:49 +08:00
SWivid
e6469f705f update shared.md 2025-06-03 22:09:13 +08:00
SWivid
31cd818095 formatting 2025-06-03 21:23:47 +08:00
Yushen CHEN
1d13664b24 Merge pull request #1063 from ionite34/dev
Fix finetune training with spaces in file paths
2025-06-03 21:18:41 +08:00
Yushen CHEN
b27471ea06 Merge pull request #1072 from hvoss-techfak/main
German Model support
2025-06-03 21:18:25 +08:00
Hendric Voss
8fb55f107e Update SHARED.md 2025-06-03 14:08:30 +02:00
Hendric Voss
ccb380b752 Added German Model 2025-06-03 14:08:03 +02:00
Ionite
3027b43953 Fix training with file path spaces 2025-05-28 15:24:35 -04:00
SWivid
ecd1c3949a Add py312 check for tempfile delete_on_close keyword 2025-05-22 23:10:29 +08:00
SWivid
2968aa184f v1.1.5 several fixes 2025-05-22 17:41:10 +08:00
SWivid
fb26b6d93e Fix #1046 tempfile related bug 2025-05-22 17:40:14 +08:00
SWivid
f7f266cdd9 preprocess only once. Fix #1043 2025-05-21 02:26:05 +08:00
SWivid
695c735737 Exclude broken dependency version with accelerate 2025-05-16 17:48:41 +08:00
SWivid
3e2a07da1d Update README.md & minor fixes 2025-05-11 19:40:37 +08:00
SWivid
c47687487c minor fix for vocab check in finetune_gradio 2025-05-05 23:32:00 +08:00
SWivid
ac79d0ec1e v1.1.4 2025-05-05 04:05:25 +08:00
SWivid
dad398c0c1 Bug Fix #1015
Ensure custom config hashable in
2025-05-05 03:55:05 +08:00
SWivid
3d969bf78d minor fix for backward compatibility to gradio multistyle feature 2025-05-05 02:07:19 +08:00
33 changed files with 541 additions and 250 deletions

View File

@@ -1,18 +0,0 @@
name: Sync to HF Space
on:
push:
branches:
- main
jobs:
trigger_curl:
runs-on: ubuntu-latest
steps:
- name: Send cURL POST request
run: |
curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
-s \
-H "Content-Type: application/json" \
-d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"

View File

@@ -2,11 +2,12 @@
[![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
[![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://swivid.github.io/F5-TTS/)
[![hfspace](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[![msspace](https://img.shields.io/badge/🤖-Space%20demo-blue)](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
[![lab](https://img.shields.io/badge/X--LANCE-Lab-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
[![lab](https://img.shields.io/badge/Peng%20Cheng-Lab-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
[![demo](https://img.shields.io/badge/GitHub-Demo-orange.svg)](https://swivid.github.io/F5-TTS/)
[![hfspace](https://img.shields.io/badge/🤗-HF%20Space-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
[![msspace](https://img.shields.io/badge/🤖-MS%20Space-blue)](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS)
[![lab](https://img.shields.io/badge/🏫-X--LANCE-grey?labelColor=lightgrey)](https://x-lance.sjtu.edu.cn/)
[![lab](https://img.shields.io/badge/🏫-SII-grey?labelColor=lightgrey)](https://www.sii.edu.cn/)
[![lab](https://img.shields.io/badge/🏫-PCL-grey?labelColor=lightgrey)](https://www.pcl.ac.cn)
<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
@@ -26,8 +27,8 @@
### Create a separate environment if needed
```bash
# Create a python 3.10 conda env (you could also use virtualenv)
conda create -n f5-tts python=3.10
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
conda create -n f5-tts python=3.11
conda activate f5-tts
```
@@ -91,7 +92,7 @@ conda activate f5-tts
> ```bash
> git clone https://github.com/SWivid/F5-TTS.git
> cd F5-TTS
> # git submodule update --init --recursive # (optional, if need > bigvgan)
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
> pip install -e .
> ```

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.3"
version = "1.1.9"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -15,17 +15,17 @@ classifiers = [
]
dependencies = [
"accelerate>=0.33.0",
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
"bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'",
"cached_path",
"click",
"datasets",
"ema_pytorch>=0.5.2",
"gradio>=3.45.2",
"gradio>=5.0.0",
"hydra-core>=1.3.0",
"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4",
"numpy<=1.26.4; python_version<='3.10'",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
@@ -38,6 +38,7 @@ dependencies = [
"tqdm>=4.65.0",
"transformers",
"transformers_stream_generator",
"unidecode",
"vocos",
"wandb",
"x_transformers>=1.31.14",

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -31,6 +31,8 @@ model:
text_mask_padding: False
conv_layers: 4
pe_attn_head: 1
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -32,6 +32,8 @@ model:
qk_norm: null # null | rms_norm
conv_layers: 4
pe_attn_head: null
attn_backend: torch # torch | flash_attn
attn_mask_enabled: False
checkpoint_activations: False # recompute activations and save memory for extra compute
mel_spec:
target_sample_rate: 24000

View File

@@ -148,10 +148,15 @@ def main():
vocab_char_map=vocab_char_map,
).to(device)
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
if not os.path.exists(ckpt_path):
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
if os.path.exists(ckpt_prefix + ".pt"):
ckpt_path = ckpt_prefix + ".pt"
elif os.path.exists(ckpt_prefix + ".safetensors"):
ckpt_path = ckpt_prefix + ".safetensors"
else:
print("Loading from self-organized training checkpoints rather than released pretrained.")
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

View File

@@ -126,8 +126,13 @@ def get_inference_prompt(
else:
text_list = text
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
ref_mel_len = ref_mel.shape[-1]
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
@@ -142,10 +147,6 @@ def get_inference_prompt(
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert min_tokens <= total_mel_len <= max_tokens, (

View File

@@ -13,7 +13,7 @@ To avoid possible inference failures, make sure you have seen through the follow
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
@@ -129,6 +129,28 @@ ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
## API Usage
```python
from importlib.resources import files
from f5_tts.api import F5TTS
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
ref_text="some call me nature, others call me mother nature.",
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
seed=None,
)
```
Check [api.py](../api.py) for more details.
## TensorRT-LLM Deployment
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
## Socket Real-time Service
Real-time voice output with chunk stream:

View File

@@ -22,6 +22,8 @@
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
- [French](#french)
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
- [German](#german)
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
- [Hindi](#hindi)
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
- [Italian](#italian)
@@ -97,6 +99,22 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
## German
#### F5-TTS Base @ de @ hvoss-techfak
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
```bash
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
## Hindi
#### F5-TTS Small @ hi @ SPRINGLab

View File

@@ -13,8 +13,8 @@ output_file = "infer_cli_story.wav"
[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""
speed = 0.8 # will ignore global speed
[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""

View File

@@ -1 +1 @@
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land. [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] Goodbye, [main] said he, [country] “Im off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."

View File

@@ -12,6 +12,7 @@ import tomli
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from unidecode import unidecode
from f5_tts.infer.utils_infer import (
cfg_strength,
@@ -112,6 +113,11 @@ parser.add_argument(
action="store_true",
help="To save each audio chunks during inference",
)
parser.add_argument(
"--no_legacy_text",
action="store_false",
help="Not to use lossy ASCII transliterations of unicode text in saved file names.",
)
parser.add_argument(
"--remove_silence",
action="store_true",
@@ -197,6 +203,12 @@ output_file = args.output_file or config.get(
)
save_chunk = args.save_chunk or config.get("save_chunk", False)
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
if save_chunk and use_legacy_text:
print(
"\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
)
remove_silence = args.remove_silence or config.get("remove_silence", False)
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
@@ -321,9 +333,10 @@ def main():
text = re.sub(reg2, "", text)
ref_audio_ = voices[voice]["ref_audio"]
ref_text_ = voices[voice]["ref_text"]
local_speed = voices[voice].get("speed", speed)
gen_text_ = text.strip()
print(f"Voice: {voice}")
audio_segment, final_sample_rate, spectragram = infer_process(
audio_segment, final_sample_rate, spectrogram = infer_process(
ref_audio_,
ref_text_,
gen_text_,
@@ -335,7 +348,7 @@ def main():
nfe_step=nfe_step,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
speed=speed,
speed=local_speed,
fix_duration=fix_duration,
device=device,
)
@@ -344,6 +357,8 @@ def main():
if save_chunk:
if len(gen_text_) > 200:
gen_text_ = gen_text_[:200] + " ... "
if use_legacy_text:
gen_text_ = unidecode(gen_text_)
sf.write(
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
audio_segment,

View File

@@ -3,6 +3,7 @@
import gc
import json
import os
import re
import tempfile
from collections import OrderedDict
@@ -41,6 +42,7 @@ from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
tempfile_kwargs,
)
from f5_tts.model import DiT, UNetT
@@ -80,6 +82,8 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
vocab_path = str(cached_path(vocab_path))
if model_cfg is None:
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
elif isinstance(model_cfg, str):
model_cfg = json.loads(model_cfg)
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
@@ -124,7 +128,7 @@ def load_text_from_file(file):
return gr.update(value=text)
@lru_cache(maxsize=100)
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
@gpu_decorator
def infer(
ref_audio_orig,
@@ -163,7 +167,7 @@ def infer(
show_info("Loading E2-TTS model...")
E2TTS_ema_model = load_e2tts()
ema_model = E2TTS_ema_model
elif isinstance(model, list) and model[0] == "Custom":
elif isinstance(model, tuple) and model[0] == "Custom":
assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
global custom_ema_model, pre_custom_path
if pre_custom_path != model[1]:
@@ -187,28 +191,24 @@ def infer(
# Remove silence
if remove_silence:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
sf.write(f.name, final_wave, final_sample_rate)
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
try:
sf.write(temp_path, final_wave, final_sample_rate)
remove_silence_for_generated_wav(f.name)
final_wave, _ = torchaudio.load(f.name)
finally:
os.unlink(temp_path)
final_wave = final_wave.squeeze().cpu().numpy()
# Save the spectrogram
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)
save_spectrogram(combined_spectrogram, spectrogram_path)
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
@@ -312,6 +312,12 @@ with gr.Blocks() as app_tts:
outputs=[ref_text_input],
)
ref_audio_input.clear(
lambda: [None, None],
None,
[ref_text_input, ref_text_file],
)
generate_btn.click(
basic_tts,
inputs=[
@@ -357,6 +363,7 @@ def parse_speechtypes_text(gen_text):
try: # if type dict
current_type_dict = json.loads(type_str)
except json.decoder.JSONDecodeError:
type_str = type_str[1:-1] # remove brace {}
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
return segments
@@ -923,12 +930,22 @@ Have a conversation with an AI using your reference voice!
)
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app:
gr.Markdown(
f"""
# E2/F5 TTS
# F5-TTS Demo Space
This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
@@ -958,7 +975,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
global tts_model_choice
if new_choice == "Custom": # override in case webpage is refreshed
custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
return (
gr.update(visible=True, value=custom_ckpt_path),
gr.update(visible=True, value=custom_vocab_path),
@@ -970,7 +987,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
global tts_model_choice
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
with open(last_used_custom, "w", encoding="utf-8") as f:
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")

View File

@@ -33,6 +33,7 @@ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}
_ref_text_cache = {}
device = (
"cuda"
@@ -44,6 +45,8 @@ device = (
else "cpu"
)
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
# -----------------------------------------
target_sample_rate = 24000
@@ -290,62 +293,74 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
# Compute a hash of the reference audio file
with open(ref_audio_orig, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
show_info("Using cached preprocessed reference audio...")
ref_audio = _ref_audio_cache[audio_hash]
else: # first pass, do preprocess
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
aseg = AudioSegment.from_file(ref_audio_orig)
if clip_short:
# 1. try to find long silence for clipping
# 1. try to find long silence for clipping
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
aseg = non_silent_wave
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")
ref_audio = f.name
aseg.export(temp_path, format="wav")
ref_audio = temp_path
# Compute a hash of the reference audio file
with open(ref_audio, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
# Cache the processed reference audio
_ref_audio_cache[audio_hash] = ref_audio
if not ref_text.strip():
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
global _ref_text_cache
if audio_hash in _ref_text_cache:
# Use cached asr transcription
show_info("Using cached reference text...")
ref_text = _ref_audio_cache[audio_hash]
ref_text = _ref_text_cache[audio_hash]
else:
show_info("No reference text provided, transcribing reference audio...")
ref_text = transcribe(ref_audio)
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
_ref_audio_cache[audio_hash] = ref_text
_ref_text_cache[audio_hash] = ref_text
else:
show_info("Using custom reference text...")
@@ -384,7 +399,7 @@ def infer_process(
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)

View File

@@ -29,11 +29,16 @@ from f5_tts.model.modules import (
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
def __init__(
self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
self.average_upsampling = average_upsampling # zipvoice-style text late average upsampling (after text encoder)
if average_upsampling:
assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"
if conv_layers > 0:
self.extra_modeling = True
@@ -45,11 +50,47 @@ class TextEmbedding(nn.Module):
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
batch, text_len, text_dim = text.shape
if audio_mask is None:
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
valid_mask = audio_mask & text_mask
audio_lens = audio_mask.sum(dim=1) # [batch]
valid_lens = valid_mask.sum(dim=1) # [batch]
upsampled_text = torch.zeros_like(text)
for i in range(batch):
audio_len = audio_lens[i].item()
valid_len = valid_lens[i].item()
if valid_len == 0:
continue
valid_ind = torch.where(valid_mask[i])[0]
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
base_repeat = audio_len // valid_len
remainder = audio_len % valid_len
indices = []
for j in range(valid_len):
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
indices.extend([j] * repeat_count)
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
upsampled = valid_data[indices] # [audio_len, text_dim]
upsampled_text[i, :audio_len, :] = upsampled
return upsampled_text
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
if self.mask_padding:
text_mask = text == 0
@@ -61,7 +102,7 @@ class TextEmbedding(nn.Module):
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
@@ -75,6 +116,9 @@ class TextEmbedding(nn.Module):
else:
text = self.text_blocks(text)
if self.average_upsampling:
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
return text
@@ -113,9 +157,12 @@ class DiT(nn.Module):
text_num_embeds=256,
text_dim=None,
text_mask_padding=True,
text_embedding_average_upsampling=False,
qk_norm=None,
conv_layers=0,
pe_attn_head=None,
attn_backend="torch", # "torch" | "flash_attn"
attn_mask_enabled=False,
long_skip_connection=False,
checkpoint_activations=False,
):
@@ -125,7 +172,11 @@ class DiT(nn.Module):
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
text_num_embeds,
text_dim,
mask_padding=text_mask_padding,
average_upsampling=text_embedding_average_upsampling,
conv_layers=conv_layers,
)
self.text_cond, self.text_uncond = None, None # text cache
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
@@ -145,6 +196,8 @@ class DiT(nn.Module):
dropout=dropout,
qk_norm=qk_norm,
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for _ in range(depth)
]
@@ -178,6 +231,33 @@ class DiT(nn.Module):
return ckpt_forward
def get_input_embed(
self,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
audio_mask: bool["b n"] | None = None, # noqa: F722
):
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
@@ -187,10 +267,11 @@ class DiT(nn.Module):
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@@ -198,18 +279,20 @@ class DiT(nn.Module):
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
text_embed = self.text_cond
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(
x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
)
x_uncond = self.get_input_embed(
x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
x = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
)
rope = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@@ -141,26 +141,15 @@ class MMDiT(nn.Module):
nn.init.constant_(self.proj_out.weight, 0)
nn.init.constant_(self.proj_out.bias, 0)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cache:
if drop_text:
if self.text_uncond is None:
@@ -174,6 +163,41 @@ class MMDiT(nn.Module):
c = self.text_embed(text, drop_text=drop_text)
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
return x, c
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch = x.shape[0]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
c = torch.cat((c_cond, c_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x, c = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
)
seq_len = x.shape[1]
text_len = text.shape[1]
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)

View File

@@ -178,26 +178,16 @@ class UNetT(nn.Module):
self.norm_out = RMSNorm(dim)
self.proj_out = nn.Linear(dim, mel_dim)
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
def get_input_embed(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
cache=False,
x, # b n d
cond, # b n d
text, # b nt
drop_audio_cond: bool = False,
drop_text: bool = False,
cache: bool = True,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
seq_len = x.shape[1]
if cache:
if drop_text:
if self.text_uncond is None:
@@ -209,8 +199,41 @@ class UNetT(nn.Module):
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
return x
def clear_cache(self):
self.text_cond, self.text_uncond = None, None
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
mask: bool["b n"] | None = None, # noqa: F722
drop_audio_cond: bool = False, # cfg for cond audio
drop_text: bool = False, # cfg for text
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
cache: bool = False,
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
else:
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
# postfix time t to input x, [b n d] -> [b n+1 d]
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
if mask is not None:

View File

@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
get_epss_timesteps,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
@@ -92,6 +93,7 @@ class CFM(nn.Module):
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
use_epss=True,
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
@@ -160,16 +162,31 @@ class CFM(nn.Module):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
)
# predict flow (cond)
if cfg_strength < 1e-5:
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
cache=True,
)
return pred
null_pred = self.transformer(
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
# predict flow (cond and uncond), for classifier-free guidance
pred_cfg = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
cfg_infer=True,
cache=True,
)
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
return pred + (pred - null_pred) * cfg_strength
# noise input
@@ -190,7 +207,10 @@ class CFM(nn.Module):
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
else:
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
@@ -270,10 +290,9 @@ class CFM(nn.Module):
else:
drop_text = False
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
# apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
pred = self.transformer(
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
)
# flow matching loss

View File

@@ -312,7 +312,7 @@ def collate_fn(batch):
max_mel_length = mel_lengths.amax()
padded_mel_specs = []
for spec in mel_specs: # TODO. maybe records mask for attention here
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = F.pad(spec, padding, value=0)
padded_mel_specs.append(padded_spec)
@@ -324,7 +324,7 @@ def collate_fn(batch):
return dict(
mel=mel_specs,
mel_lengths=mel_lengths,
mel_lengths=mel_lengths, # records for padding mask
text=text,
text_lengths=text_lengths,
)

View File

@@ -6,6 +6,7 @@ nt - text sequence
nw - raw wave length
d - dimension
"""
# flake8: noqa
from __future__ import annotations
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
from f5_tts.model.utils import is_package_available
# raw wav to mel spec
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
@@ -417,9 +420,9 @@ class Attention(nn.Module):
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b n d"] = None, # context c
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
@@ -431,19 +434,30 @@ class Attention(nn.Module):
# Attention processor
if is_package_available("flash_attn"):
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_varlen_func, flash_attn_func
class AttnProcessor:
def __init__(
self,
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.pe_attn_head = pe_attn_head
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
@@ -479,16 +493,40 @@ class AttnProcessor:
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
if self.attn_backend == "torch":
# mask. e.g. inference got a batch with different target durations, mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
@@ -514,9 +552,9 @@ class JointAttnProcessor:
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
x: float["b n d"], # noised input x
c: float["b nt d"] = None, # context c, here text
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
@@ -608,12 +646,27 @@ class JointAttnProcessor:
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
def __init__(
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
qk_norm=None,
pe_attn_head=None,
attn_backend="torch", # "torch" or "flash_attn"
attn_mask_enabled=True,
):
super().__init__()
self.attn_norm = AdaLayerNorm(dim)
self.attn = Attention(
processor=AttnProcessor(pe_attn_head=pe_attn_head),
processor=AttnProcessor(
pe_attn_head=pe_attn_head,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
def forward(self, timestep: float["b"]):
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d

View File

@@ -149,7 +149,7 @@ class Trainer:
if self.is_main:
checkpoint = dict(
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
optimizer_state_dict=self.optimizer.state_dict(),
ema_model_state_dict=self.ema_model.state_dict(),
scheduler_state_dict=self.scheduler.state_dict(),
update=update,
@@ -242,7 +242,7 @@ class Trainer:
del checkpoint["model_state_dict"][key]
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler:
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
update = checkpoint["update"]

View File

@@ -35,6 +35,16 @@ def default(v, d):
return v if exists(v) else d
def is_package_available(package_name: str) -> bool:
try:
import importlib
package_exists = importlib.util.find_spec(package_name) is not None
return package_exists
except Exception:
return False
# tensor helpers
@@ -189,3 +199,22 @@ def repetition_found(text, length=2, tolerance=10):
if count > tolerance:
return True
return False
# get the empirically pruned step for sampling
def get_epss_timesteps(n, device, dtype):
dt = 1 / 32
predefined_timesteps = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = predefined_timesteps.get(n, [])
if not t:
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
return dt * torch.tensor(t, device=device, dtype=dtype)

View File

@@ -220,8 +220,8 @@ def get_args():
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -244,7 +244,7 @@ async def send(
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
save_sample_rate: int = 24000,
):
total_duration = 0.0
latency_data = []
@@ -254,7 +254,7 @@ async def send(
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
@@ -310,8 +310,9 @@ async def send(
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, estimated_target_duration))
total_duration += estimated_target_duration
actual_duration = len(audio) / save_sample_rate
latency_data.append((end, actual_duration))
total_duration += actual_duration
return total_duration, latency_data
@@ -416,7 +417,7 @@ async def main():
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
save_sample_rate=24000,
)
)
tasks.append(task)

View File

@@ -82,7 +82,7 @@ def prepare_request(
samples,
reference_text,
target_text,
sample_rate=16000,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
@@ -106,8 +106,8 @@ def prepare_request(
return data
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -129,7 +129,7 @@ if __name__ == "__main__":
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)

View File

@@ -33,7 +33,7 @@ parameters [
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
value: {string_value:"24000"}
},
{
key: "vocoder",

View File

@@ -1,5 +1,11 @@
# Training
Check your FFmpeg installation:
```bash
ffmpeg -version
```
If not found, install it first (or skip assuming you know of other backends available).
## Prepare Dataset
Example data processing scripts, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`.

View File

@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
out_dir.mkdir(exist_ok=True, parents=True)
print(f"\nSaving to {out_dir} ...")
# Save dataset with improved batch size for better I/O performance
raw_arrow_path = out_dir / "raw.arrow"
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# Save durations to JSON
dur_json_path = out_dir / "duration.json"

View File

@@ -181,6 +181,7 @@ def main():
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:

View File

@@ -68,6 +68,7 @@ def main():
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)

View File

@@ -62,6 +62,7 @@ def main():
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:

View File

@@ -39,6 +39,7 @@ def main():
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
for line in tqdm(result, desc="Writing to raw.arrow ..."):
writer.write(line)
writer.finalize()
# dup a json separately saving duration in case for DynamicBatchSampler ease
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:

View File

@@ -178,50 +178,12 @@ def get_audio_duration(audio_path):
return audio.shape[1] / sample_rate
def clear_text(text):
"""Clean and prepare text by lowering the case and stripping whitespace."""
return text.lower().strip()
def get_rms(
y,
frame_length=2048,
hop_length=512,
pad_mode="constant",
): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
padding = (int(frame_length // 2), int(frame_length // 2))
y = np.pad(y, padding, mode=pad_mode)
axis = -1
# put our new within-frame axis at the end for now
out_strides = y.strides + tuple([y.strides[axis]])
# Reduce the shape on the framing axis
x_shape_trimmed = list(y.shape)
x_shape_trimmed[axis] -= frame_length - 1
out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
if axis < 0:
target_axis = axis - 1
else:
target_axis = axis + 1
xw = np.moveaxis(xw, -1, target_axis)
# Downsample along the target axis
slices = [slice(None)] * xw.ndim
slices[axis] = slice(0, None, hop_length)
x = xw[tuple(slices)]
# Calculate power
power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
return np.sqrt(power)
class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
def __init__(
self,
sr: int,
threshold: float = -40.0,
min_length: int = 2000,
min_length: int = 20000, # 20 seconds
min_interval: int = 300,
hop_size: int = 20,
max_sil_kept: int = 2000,
@@ -252,7 +214,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
samples = waveform
if samples.shape[0] <= self.min_length:
return [waveform]
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
sil_tags = []
silence_start = None
clip_start = 0
@@ -306,8 +268,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
silence_end = min(total_frames, silence_start + self.max_sil_kept)
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
sil_tags.append((pos, total_frames + 1))
# Apply and return slices.
####音频+起始时间+终止时间
# Apply and return slices: [chunk, start, end]
if len(sil_tags) == 0:
return [[waveform, 0, int(total_frames * self.hop_size)]]
else:
@@ -434,7 +395,7 @@ def start_training(
fp16 = ""
cmd = (
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
f" --learning_rate {learning_rate}"
f" --batch_size_per_gpu {batch_size_per_gpu}"
f" --batch_size_type {batch_size_type}"
@@ -453,7 +414,7 @@ def start_training(
cmd += " --finetune"
if file_checkpoint_train != "":
cmd += f" --pretrain {file_checkpoint_train}"
cmd += f' --pretrain "{file_checkpoint_train}"'
if tokenizer_file != "":
cmd += f" --tokenizer_path {tokenizer_file}"
@@ -707,7 +668,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
try:
text = transcribe(file_segment, language)
text = text.lower().strip().replace('"', "")
text = text.strip()
data += f"{name_segment}|{text}\n"
@@ -816,7 +777,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
error_files.append([file_audio, "very short text length 3"])
continue
text = clear_text(text)
text = text.strip()
text = convert_char_to_pinyin([text], polyphone=True)[0]
audio_path_list.append(file_audio)
@@ -835,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
min_second = round(min(duration_list), 2)
max_second = round(max(duration_list), 2)
with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
with ArrowWriter(path=file_raw) as writer:
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
writer.write(line)
writer.finalize()
with open(file_duration, "w") as f:
json.dump({"duration": duration_list}, f, ensure_ascii=False)
@@ -1099,7 +1061,7 @@ def vocab_extend(project_name, symbols, model_type):
return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
def vocab_check(project_name):
def vocab_check(project_name, tokenizer_type):
name_project = project_name
path_project = os.path.join(path_data, name_project)
@@ -1127,7 +1089,9 @@ def vocab_check(project_name):
if len(sp) != 2:
continue
text = sp[1].lower().strip()
text = sp[1].strip()
if tokenizer_type == "pinyin":
text = convert_char_to_pinyin([text], polyphone=True)[0]
for t in text:
if t not in vocab and t not in miss_symbols_keep:
@@ -1232,8 +1196,8 @@ def infer(
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
tts_api.infer(
ref_file=ref_audio,
ref_text=ref_text.lower().strip(),
gen_text=gen_text.lower().strip(),
ref_text=ref_text.strip(),
gen_text=gen_text.strip(),
nfe_step=nfe_step,
speed=speed,
remove_silence=remove_silence,
@@ -1498,7 +1462,9 @@ Using the extended model, you can finetune to a new language that is missing sym
txt_info_extend = gr.Textbox(label="Info", value="")
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
check_button.click(
fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
)
extend_button.click(
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
)