mirror of
https://github.com/SWivid/F5-TTS.git
synced 2025-12-05 20:40:12 -08:00
Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77d3ec623b | ||
|
|
186799d6dc | ||
|
|
31bb78f2ab | ||
|
|
e61824009a | ||
|
|
06a74910bd | ||
|
|
ac3c43595c | ||
|
|
605fa13b42 | ||
|
|
5f35f27230 | ||
|
|
c96c3aeed8 | ||
|
|
9b60fe6a34 | ||
|
|
a275798a2f | ||
|
|
efc7a7498b | ||
|
|
9842314127 | ||
|
|
69b0e0110e | ||
|
|
52c84776e5 | ||
|
|
ebbd7bd91f | ||
|
|
ac42286d04 | ||
|
|
d937efa6f3 | ||
|
|
8975fca803 | ||
|
|
8b0053ad0c | ||
|
|
b3ef4ed1d7 | ||
|
|
b1a9438496 | ||
|
|
0914170e98 | ||
|
|
c6ebad0220 | ||
|
|
cfaba6387f | ||
|
|
646f34b20f | ||
|
|
2e2acc6ea2 | ||
|
|
6fbe7592f5 | ||
|
|
7e37bc5d9a | ||
|
|
35f130ee85 | ||
|
|
e6469f705f | ||
|
|
31cd818095 | ||
|
|
1d13664b24 | ||
|
|
b27471ea06 | ||
|
|
8fb55f107e | ||
|
|
ccb380b752 | ||
|
|
3027b43953 | ||
|
|
ecd1c3949a | ||
|
|
2968aa184f | ||
|
|
fb26b6d93e | ||
|
|
f7f266cdd9 | ||
|
|
695c735737 | ||
|
|
3e2a07da1d | ||
|
|
c47687487c | ||
|
|
ac79d0ec1e | ||
|
|
dad398c0c1 | ||
|
|
3d969bf78d |
18
.github/workflows/sync-hf.yaml
vendored
18
.github/workflows/sync-hf.yaml
vendored
@@ -1,18 +0,0 @@
|
||||
name: Sync to HF Space
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
trigger_curl:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Send cURL POST request
|
||||
run: |
|
||||
curl -X POST https://mrfakename-sync-f5.hf.space/gradio_api/call/refresh \
|
||||
-s \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"data\": [\"${{ secrets.REFRESH_PASSWORD }}\"]}"
|
||||
17
README.md
17
README.md
@@ -2,11 +2,12 @@
|
||||
|
||||
[](https://github.com/SWivid/F5-TTS)
|
||||
[](https://arxiv.org/abs/2410.06885)
|
||||
[](https://swivid.github.io/F5-TTS/)
|
||||
[](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
[](https://modelscope.cn/studios/modelscope/E2-F5-TTS)
|
||||
[](https://x-lance.sjtu.edu.cn/)
|
||||
[](https://www.pcl.ac.cn)
|
||||
[](https://swivid.github.io/F5-TTS/)
|
||||
[](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
[](https://modelscope.cn/studios/AI-ModelScope/E2-F5-TTS)
|
||||
[](https://x-lance.sjtu.edu.cn/)
|
||||
[](https://www.sii.edu.cn/)
|
||||
[](https://www.pcl.ac.cn)
|
||||
<!-- <img src="https://github.com/user-attachments/assets/12d7749c-071a-427c-81bf-b87b91def670" alt="Watermark" style="width: 40px; height: auto"> -->
|
||||
|
||||
**F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
|
||||
@@ -26,8 +27,8 @@
|
||||
### Create a separate environment if needed
|
||||
|
||||
```bash
|
||||
# Create a python 3.10 conda env (you could also use virtualenv)
|
||||
conda create -n f5-tts python=3.10
|
||||
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
|
||||
conda create -n f5-tts python=3.11
|
||||
conda activate f5-tts
|
||||
```
|
||||
|
||||
@@ -91,7 +92,7 @@ conda activate f5-tts
|
||||
> ```bash
|
||||
> git clone https://github.com/SWivid/F5-TTS.git
|
||||
> cd F5-TTS
|
||||
> # git submodule update --init --recursive # (optional, if need > bigvgan)
|
||||
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
|
||||
> pip install -e .
|
||||
> ```
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "f5-tts"
|
||||
version = "1.1.3"
|
||||
version = "1.1.9"
|
||||
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT License"}
|
||||
@@ -15,17 +15,17 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"accelerate>=0.33.0",
|
||||
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
|
||||
"bitsandbytes>0.37.0; platform_machine!='arm64' and platform_system!='Darwin'",
|
||||
"cached_path",
|
||||
"click",
|
||||
"datasets",
|
||||
"ema_pytorch>=0.5.2",
|
||||
"gradio>=3.45.2",
|
||||
"gradio>=5.0.0",
|
||||
"hydra-core>=1.3.0",
|
||||
"jieba",
|
||||
"librosa",
|
||||
"matplotlib",
|
||||
"numpy<=1.26.4",
|
||||
"numpy<=1.26.4; python_version<='3.10'",
|
||||
"pydantic<=2.10.6",
|
||||
"pydub",
|
||||
"pypinyin",
|
||||
@@ -38,6 +38,7 @@ dependencies = [
|
||||
"tqdm>=4.65.0",
|
||||
"transformers",
|
||||
"transformers_stream_generator",
|
||||
"unidecode",
|
||||
"vocos",
|
||||
"wandb",
|
||||
"x_transformers>=1.31.14",
|
||||
|
||||
@@ -31,6 +31,8 @@ model:
|
||||
text_mask_padding: False
|
||||
conv_layers: 4
|
||||
pe_attn_head: 1
|
||||
attn_backend: torch # torch | flash_attn
|
||||
attn_mask_enabled: False
|
||||
checkpoint_activations: False # recompute activations and save memory for extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
|
||||
@@ -31,6 +31,8 @@ model:
|
||||
text_mask_padding: False
|
||||
conv_layers: 4
|
||||
pe_attn_head: 1
|
||||
attn_backend: torch # torch | flash_attn
|
||||
attn_mask_enabled: False
|
||||
checkpoint_activations: False # recompute activations and save memory for extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
|
||||
@@ -32,6 +32,8 @@ model:
|
||||
qk_norm: null # null | rms_norm
|
||||
conv_layers: 4
|
||||
pe_attn_head: null
|
||||
attn_backend: torch # torch | flash_attn
|
||||
attn_mask_enabled: False
|
||||
checkpoint_activations: False # recompute activations and save memory for extra compute
|
||||
mel_spec:
|
||||
target_sample_rate: 24000
|
||||
|
||||
@@ -148,10 +148,15 @@ def main():
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
|
||||
if not os.path.exists(ckpt_path):
|
||||
ckpt_prefix = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}"
|
||||
if os.path.exists(ckpt_prefix + ".pt"):
|
||||
ckpt_path = ckpt_prefix + ".pt"
|
||||
elif os.path.exists(ckpt_prefix + ".safetensors"):
|
||||
ckpt_path = ckpt_prefix + ".safetensors"
|
||||
else:
|
||||
print("Loading from self-organized training checkpoints rather than released pretrained.")
|
||||
ckpt_path = rel_path + f"/{model_cfg.ckpts.save_dir}/model_{ckpt_step}.pt"
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
|
||||
|
||||
|
||||
@@ -126,8 +126,13 @@ def get_inference_prompt(
|
||||
else:
|
||||
text_list = text
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# Duration, mel frame length
|
||||
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||
ref_mel_len = ref_mel.shape[-1]
|
||||
|
||||
if use_truth_duration:
|
||||
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||
if gt_sr != target_sample_rate:
|
||||
@@ -142,10 +147,6 @@ def get_inference_prompt(
|
||||
gen_text_len = len(gt_text.encode("utf-8"))
|
||||
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
||||
|
||||
# to mel spectrogram
|
||||
ref_mel = mel_spectrogram(ref_audio)
|
||||
ref_mel = ref_mel.squeeze(0)
|
||||
|
||||
# deal with batch
|
||||
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||
assert min_tokens <= total_mel_len <= max_tokens, (
|
||||
|
||||
@@ -13,7 +13,7 @@ To avoid possible inference failures, make sure you have seen through the follow
|
||||
- Add some spaces (blank: " ") or punctuations (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
|
||||
- If English punctuation marks the end of a sentence, make sure there is a space " " after it. Otherwise not regarded as when chunk.
|
||||
- <ins>Preprocess numbers</ins> to Chinese letters if you want to have them read in Chinese, otherwise in English.
|
||||
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
|
||||
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
|
||||
- Try <ins>turn off `use_ema` if using an early-stage</ins> finetuned checkpoint (which goes just few updates).
|
||||
|
||||
|
||||
@@ -129,6 +129,28 @@ ref_text = ""
|
||||
```
|
||||
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
|
||||
|
||||
## API Usage
|
||||
|
||||
```python
|
||||
from importlib.resources import files
|
||||
from f5_tts.api import F5TTS
|
||||
|
||||
f5tts = F5TTS()
|
||||
wav, sr, spec = f5tts.infer(
|
||||
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
|
||||
ref_text="some call me nature, others call me mother nature.",
|
||||
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
|
||||
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
|
||||
file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
|
||||
seed=None,
|
||||
)
|
||||
```
|
||||
Check [api.py](../api.py) for more details.
|
||||
|
||||
## TensorRT-LLM Deployment
|
||||
|
||||
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
|
||||
|
||||
## Socket Real-time Service
|
||||
|
||||
Real-time voice output with chunk stream:
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
|
||||
- [French](#french)
|
||||
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
|
||||
- [German](#german)
|
||||
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
|
||||
- [Hindi](#hindi)
|
||||
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
|
||||
- [Italian](#italian)
|
||||
@@ -97,6 +99,22 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
|
||||
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
|
||||
|
||||
|
||||
## German
|
||||
|
||||
#### F5-TTS Base @ de @ hvoss-techfak
|
||||
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
||||
|:---:|:------------:|:-----------:|:-------------:|
|
||||
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
|
||||
|
||||
```bash
|
||||
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
|
||||
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
|
||||
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
|
||||
```
|
||||
|
||||
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
|
||||
|
||||
|
||||
## Hindi
|
||||
|
||||
#### F5-TTS Small @ hi @ SPRINGLab
|
||||
|
||||
@@ -13,8 +13,8 @@ output_file = "infer_cli_story.wav"
|
||||
[voices.town]
|
||||
ref_audio = "infer/examples/multi/town.flac"
|
||||
ref_text = ""
|
||||
speed = 0.8 # will ignore global speed
|
||||
|
||||
[voices.country]
|
||||
ref_audio = "infer/examples/multi/country.flac"
|
||||
ref_text = ""
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.” [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] “Goodbye,” [main] said he, [country] “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace.”
|
||||
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with [town] "My poor dear friend, you live here no better than the ants! Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land." [main] So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor. [country] "Goodbye," [main] said he, [country] "I'm off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."
|
||||
@@ -12,6 +12,7 @@ import tomli
|
||||
from cached_path import cached_path
|
||||
from hydra.utils import get_class
|
||||
from omegaconf import OmegaConf
|
||||
from unidecode import unidecode
|
||||
|
||||
from f5_tts.infer.utils_infer import (
|
||||
cfg_strength,
|
||||
@@ -112,6 +113,11 @@ parser.add_argument(
|
||||
action="store_true",
|
||||
help="To save each audio chunks during inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_legacy_text",
|
||||
action="store_false",
|
||||
help="Not to use lossy ASCII transliterations of unicode text in saved file names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_silence",
|
||||
action="store_true",
|
||||
@@ -197,6 +203,12 @@ output_file = args.output_file or config.get(
|
||||
)
|
||||
|
||||
save_chunk = args.save_chunk or config.get("save_chunk", False)
|
||||
use_legacy_text = args.no_legacy_text or config.get("no_legacy_text", False) # no_legacy_text is a store_false arg
|
||||
if save_chunk and use_legacy_text:
|
||||
print(
|
||||
"\nWarning to --save_chunk: lossy ASCII transliterations of unicode text for legacy (.wav) file names, --no_legacy_text to disable.\n"
|
||||
)
|
||||
|
||||
remove_silence = args.remove_silence or config.get("remove_silence", False)
|
||||
load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
|
||||
|
||||
@@ -321,9 +333,10 @@ def main():
|
||||
text = re.sub(reg2, "", text)
|
||||
ref_audio_ = voices[voice]["ref_audio"]
|
||||
ref_text_ = voices[voice]["ref_text"]
|
||||
local_speed = voices[voice].get("speed", speed)
|
||||
gen_text_ = text.strip()
|
||||
print(f"Voice: {voice}")
|
||||
audio_segment, final_sample_rate, spectragram = infer_process(
|
||||
audio_segment, final_sample_rate, spectrogram = infer_process(
|
||||
ref_audio_,
|
||||
ref_text_,
|
||||
gen_text_,
|
||||
@@ -335,7 +348,7 @@ def main():
|
||||
nfe_step=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
speed=speed,
|
||||
speed=local_speed,
|
||||
fix_duration=fix_duration,
|
||||
device=device,
|
||||
)
|
||||
@@ -344,6 +357,8 @@ def main():
|
||||
if save_chunk:
|
||||
if len(gen_text_) > 200:
|
||||
gen_text_ = gen_text_[:200] + " ... "
|
||||
if use_legacy_text:
|
||||
gen_text_ = unidecode(gen_text_)
|
||||
sf.write(
|
||||
os.path.join(output_chunk_dir, f"{len(generated_audio_segments) - 1}_{gen_text_}.wav"),
|
||||
audio_segment,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
@@ -41,6 +42,7 @@ from f5_tts.infer.utils_infer import (
|
||||
preprocess_ref_audio_text,
|
||||
remove_silence_for_generated_wav,
|
||||
save_spectrogram,
|
||||
tempfile_kwargs,
|
||||
)
|
||||
from f5_tts.model import DiT, UNetT
|
||||
|
||||
@@ -80,6 +82,8 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
|
||||
vocab_path = str(cached_path(vocab_path))
|
||||
if model_cfg is None:
|
||||
model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
|
||||
elif isinstance(model_cfg, str):
|
||||
model_cfg = json.loads(model_cfg)
|
||||
return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
|
||||
|
||||
|
||||
@@ -124,7 +128,7 @@ def load_text_from_file(file):
|
||||
return gr.update(value=text)
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
|
||||
@gpu_decorator
|
||||
def infer(
|
||||
ref_audio_orig,
|
||||
@@ -163,7 +167,7 @@ def infer(
|
||||
show_info("Loading E2-TTS model...")
|
||||
E2TTS_ema_model = load_e2tts()
|
||||
ema_model = E2TTS_ema_model
|
||||
elif isinstance(model, list) and model[0] == "Custom":
|
||||
elif isinstance(model, tuple) and model[0] == "Custom":
|
||||
assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
|
||||
global custom_ema_model, pre_custom_path
|
||||
if pre_custom_path != model[1]:
|
||||
@@ -187,28 +191,24 @@ def infer(
|
||||
|
||||
# Remove silence
|
||||
if remove_silence:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
sf.write(f.name, final_wave, final_sample_rate)
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
try:
|
||||
sf.write(temp_path, final_wave, final_sample_rate)
|
||||
remove_silence_for_generated_wav(f.name)
|
||||
final_wave, _ = torchaudio.load(f.name)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
final_wave = final_wave.squeeze().cpu().numpy()
|
||||
|
||||
# Save the spectrogram
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
|
||||
spectrogram_path = tmp_spectrogram.name
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
save_spectrogram(combined_spectrogram, spectrogram_path)
|
||||
|
||||
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
|
||||
|
||||
|
||||
with gr.Blocks() as app_credits:
|
||||
gr.Markdown("""
|
||||
# Credits
|
||||
|
||||
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
|
||||
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
|
||||
""")
|
||||
with gr.Blocks() as app_tts:
|
||||
gr.Markdown("# Batched TTS")
|
||||
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
|
||||
@@ -312,6 +312,12 @@ with gr.Blocks() as app_tts:
|
||||
outputs=[ref_text_input],
|
||||
)
|
||||
|
||||
ref_audio_input.clear(
|
||||
lambda: [None, None],
|
||||
None,
|
||||
[ref_text_input, ref_text_file],
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
basic_tts,
|
||||
inputs=[
|
||||
@@ -357,6 +363,7 @@ def parse_speechtypes_text(gen_text):
|
||||
try: # if type dict
|
||||
current_type_dict = json.loads(type_str)
|
||||
except json.decoder.JSONDecodeError:
|
||||
type_str = type_str[1:-1] # remove brace {}
|
||||
current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
|
||||
|
||||
return segments
|
||||
@@ -923,12 +930,22 @@ Have a conversation with an AI using your reference voice!
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as app_credits:
|
||||
gr.Markdown("""
|
||||
# Credits
|
||||
|
||||
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
|
||||
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
|
||||
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
|
||||
""")
|
||||
|
||||
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown(
|
||||
f"""
|
||||
# E2/F5 TTS
|
||||
# F5-TTS Demo Space
|
||||
|
||||
This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
|
||||
This is {"a local web UI for [F5-TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
|
||||
|
||||
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
|
||||
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
|
||||
@@ -958,7 +975,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
global tts_model_choice
|
||||
if new_choice == "Custom": # override in case webpage is refreshed
|
||||
custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
|
||||
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
|
||||
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
|
||||
return (
|
||||
gr.update(visible=True, value=custom_ckpt_path),
|
||||
gr.update(visible=True, value=custom_vocab_path),
|
||||
@@ -970,7 +987,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
|
||||
|
||||
def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
|
||||
global tts_model_choice
|
||||
tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
|
||||
tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
|
||||
with open(last_used_custom, "w", encoding="utf-8") as f:
|
||||
f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
|
||||
|
||||
|
||||
_ref_audio_cache = {}
|
||||
_ref_text_cache = {}
|
||||
|
||||
device = (
|
||||
"cuda"
|
||||
@@ -44,6 +45,8 @@ device = (
|
||||
else "cpu"
|
||||
)
|
||||
|
||||
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
|
||||
|
||||
# -----------------------------------------
|
||||
|
||||
target_sample_rate = 24000
|
||||
@@ -290,62 +293,74 @@ def remove_silence_edges(audio, silence_threshold=-42):
|
||||
# preprocess reference audio and text
|
||||
|
||||
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
|
||||
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
|
||||
show_info("Converting audio...")
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio_orig, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
|
||||
global _ref_audio_cache
|
||||
|
||||
if audio_hash in _ref_audio_cache:
|
||||
show_info("Using cached preprocessed reference audio...")
|
||||
ref_audio = _ref_audio_cache[audio_hash]
|
||||
|
||||
else: # first pass, do preprocess
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
|
||||
temp_path = f.name
|
||||
|
||||
aseg = AudioSegment.from_file(ref_audio_orig)
|
||||
|
||||
if clip_short:
|
||||
# 1. try to find long silence for clipping
|
||||
# 1. try to find long silence for clipping
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (1)")
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 12s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
aseg = non_silent_wave
|
||||
|
||||
aseg = non_silent_wave
|
||||
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
# 3. if no proper silence found for clipping
|
||||
if len(aseg) > 12000:
|
||||
aseg = aseg[:12000]
|
||||
show_info("Audio is over 12s, clipping short. (3)")
|
||||
|
||||
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
|
||||
aseg.export(f.name, format="wav")
|
||||
ref_audio = f.name
|
||||
aseg.export(temp_path, format="wav")
|
||||
ref_audio = temp_path
|
||||
|
||||
# Compute a hash of the reference audio file
|
||||
with open(ref_audio, "rb") as audio_file:
|
||||
audio_data = audio_file.read()
|
||||
audio_hash = hashlib.md5(audio_data).hexdigest()
|
||||
# Cache the processed reference audio
|
||||
_ref_audio_cache[audio_hash] = ref_audio
|
||||
|
||||
if not ref_text.strip():
|
||||
global _ref_audio_cache
|
||||
if audio_hash in _ref_audio_cache:
|
||||
global _ref_text_cache
|
||||
if audio_hash in _ref_text_cache:
|
||||
# Use cached asr transcription
|
||||
show_info("Using cached reference text...")
|
||||
ref_text = _ref_audio_cache[audio_hash]
|
||||
ref_text = _ref_text_cache[audio_hash]
|
||||
else:
|
||||
show_info("No reference text provided, transcribing reference audio...")
|
||||
ref_text = transcribe(ref_audio)
|
||||
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
|
||||
_ref_audio_cache[audio_hash] = ref_text
|
||||
_ref_text_cache[audio_hash] = ref_text
|
||||
else:
|
||||
show_info("Using custom reference text...")
|
||||
|
||||
@@ -384,7 +399,7 @@ def infer_process(
|
||||
):
|
||||
# Split the input text into batches
|
||||
audio, sr = torchaudio.load(ref_audio)
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
|
||||
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
|
||||
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
||||
for i, gen_text in enumerate(gen_text_batches):
|
||||
print(f"gen_text {i}", gen_text)
|
||||
|
||||
@@ -29,11 +29,16 @@ from f5_tts.model.modules import (
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
|
||||
def __init__(
|
||||
self, text_num_embeds, text_dim, mask_padding=True, average_upsampling=False, conv_layers=0, conv_mult=2
|
||||
):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
self.mask_padding = mask_padding # mask filler and batch padding tokens or not
|
||||
self.average_upsampling = average_upsampling # zipvoice-style text late average upsampling (after text encoder)
|
||||
if average_upsampling:
|
||||
assert mask_padding, "text_embedding_average_upsampling requires text_mask_padding to be True"
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
@@ -45,11 +50,47 @@ class TextEmbedding(nn.Module):
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
|
||||
batch, text_len, text_dim = text.shape
|
||||
|
||||
if audio_mask is None:
|
||||
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
|
||||
valid_mask = audio_mask & text_mask
|
||||
audio_lens = audio_mask.sum(dim=1) # [batch]
|
||||
valid_lens = valid_mask.sum(dim=1) # [batch]
|
||||
|
||||
upsampled_text = torch.zeros_like(text)
|
||||
|
||||
for i in range(batch):
|
||||
audio_len = audio_lens[i].item()
|
||||
valid_len = valid_lens[i].item()
|
||||
|
||||
if valid_len == 0:
|
||||
continue
|
||||
|
||||
valid_ind = torch.where(valid_mask[i])[0]
|
||||
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
|
||||
|
||||
base_repeat = audio_len // valid_len
|
||||
remainder = audio_len % valid_len
|
||||
|
||||
indices = []
|
||||
for j in range(valid_len):
|
||||
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
||||
indices.extend([j] * repeat_count)
|
||||
|
||||
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
|
||||
upsampled = valid_data[indices] # [audio_len, text_dim]
|
||||
|
||||
upsampled_text[i, :audio_len, :] = upsampled
|
||||
|
||||
return upsampled_text
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None): # noqa: F722
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0) # (opt.) if not self.average_upsampling:
|
||||
if self.mask_padding:
|
||||
text_mask = text == 0
|
||||
|
||||
@@ -61,7 +102,7 @@ class TextEmbedding(nn.Module):
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||
batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
|
||||
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||
text_pos_embed = self.freqs_cis[pos_idx]
|
||||
text = text + text_pos_embed
|
||||
@@ -75,6 +116,9 @@ class TextEmbedding(nn.Module):
|
||||
else:
|
||||
text = self.text_blocks(text)
|
||||
|
||||
if self.average_upsampling:
|
||||
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@@ -113,9 +157,12 @@ class DiT(nn.Module):
|
||||
text_num_embeds=256,
|
||||
text_dim=None,
|
||||
text_mask_padding=True,
|
||||
text_embedding_average_upsampling=False,
|
||||
qk_norm=None,
|
||||
conv_layers=0,
|
||||
pe_attn_head=None,
|
||||
attn_backend="torch", # "torch" | "flash_attn"
|
||||
attn_mask_enabled=False,
|
||||
long_skip_connection=False,
|
||||
checkpoint_activations=False,
|
||||
):
|
||||
@@ -125,7 +172,11 @@ class DiT(nn.Module):
|
||||
if text_dim is None:
|
||||
text_dim = mel_dim
|
||||
self.text_embed = TextEmbedding(
|
||||
text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
|
||||
text_num_embeds,
|
||||
text_dim,
|
||||
mask_padding=text_mask_padding,
|
||||
average_upsampling=text_embedding_average_upsampling,
|
||||
conv_layers=conv_layers,
|
||||
)
|
||||
self.text_cond, self.text_uncond = None, None # text cache
|
||||
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||
@@ -145,6 +196,8 @@ class DiT(nn.Module):
|
||||
dropout=dropout,
|
||||
qk_norm=qk_norm,
|
||||
pe_attn_head=pe_attn_head,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
@@ -178,6 +231,33 @@ class DiT(nn.Module):
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def get_input_embed(
|
||||
self,
|
||||
x, # b n d
|
||||
cond, # b n d
|
||||
text, # b nt
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
audio_mask: bool["b n"] | None = None, # noqa: F722
|
||||
):
|
||||
seq_len = x.shape[1]
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
return x
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
@@ -187,10 +267,11 @@ class DiT(nn.Module):
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
cache: bool = False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
@@ -198,18 +279,20 @@ class DiT(nn.Module):
|
||||
|
||||
# t: conditioning time, text: text, x: noised audio + cond audio + text
|
||||
t = self.time_embed(time)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
self.text_uncond = self.text_embed(text, seq_len, drop_text=True)
|
||||
text_embed = self.text_uncond
|
||||
else:
|
||||
if self.text_cond is None:
|
||||
self.text_cond = self.text_embed(text, seq_len, drop_text=False)
|
||||
text_embed = self.text_cond
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache, audio_mask=mask
|
||||
)
|
||||
x_uncond = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache, audio_mask=mask
|
||||
)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
x = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache, audio_mask=mask
|
||||
)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
|
||||
@@ -141,26 +141,15 @@ class MMDiT(nn.Module):
|
||||
nn.init.constant_(self.proj_out.weight, 0)
|
||||
nn.init.constant_(self.proj_out.bias, 0)
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
def get_input_embed(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
x, # b n d
|
||||
cond, # b n d
|
||||
text, # b nt
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
):
|
||||
batch = x.shape[0]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
@@ -174,6 +163,41 @@ class MMDiT(nn.Module):
|
||||
c = self.text_embed(text, drop_text=drop_text)
|
||||
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
return x, c
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
cache: bool = False,
|
||||
):
|
||||
batch = x.shape[0]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
|
||||
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
c = torch.cat((c_cond, c_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
else:
|
||||
x, c = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
|
||||
)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
text_len = text.shape[1]
|
||||
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
@@ -178,26 +178,16 @@ class UNetT(nn.Module):
|
||||
self.norm_out = RMSNorm(dim)
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
def get_input_embed(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
drop_audio_cond, # cfg for cond audio
|
||||
drop_text, # cfg for text
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
cache=False,
|
||||
x, # b n d
|
||||
cond, # b n d
|
||||
text, # b nt
|
||||
drop_audio_cond: bool = False,
|
||||
drop_text: bool = False,
|
||||
cache: bool = True,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
seq_len = x.shape[1]
|
||||
if cache:
|
||||
if drop_text:
|
||||
if self.text_uncond is None:
|
||||
@@ -209,8 +199,41 @@ class UNetT(nn.Module):
|
||||
text_embed = self.text_cond
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
return x
|
||||
|
||||
def clear_cache(self):
|
||||
self.text_cond, self.text_uncond = None, None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||
text: int["b nt"], # text # noqa: F722
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
drop_audio_cond: bool = False, # cfg for cond audio
|
||||
drop_text: bool = False, # cfg for text
|
||||
cfg_infer: bool = False, # cfg inference, pack cond & uncond forward
|
||||
cache: bool = False,
|
||||
):
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
time = time.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
|
||||
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
|
||||
x = torch.cat((x_cond, x_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
else:
|
||||
x = self.get_input_embed(x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache)
|
||||
|
||||
# postfix time t to input x, [b n d] -> [b n+1 d]
|
||||
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
|
||||
if mask is not None:
|
||||
|
||||
@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
|
||||
from f5_tts.model.utils import (
|
||||
default,
|
||||
exists,
|
||||
get_epss_timesteps,
|
||||
lens_to_mask,
|
||||
list_str_to_idx,
|
||||
list_str_to_tensor,
|
||||
@@ -92,6 +93,7 @@ class CFM(nn.Module):
|
||||
seed: int | None = None,
|
||||
max_duration=4096,
|
||||
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||
use_epss=True,
|
||||
no_ref_audio=False,
|
||||
duplicate_test=False,
|
||||
t_inter=0.1,
|
||||
@@ -160,16 +162,31 @@ class CFM(nn.Module):
|
||||
# at each step, conditioning is fixed
|
||||
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
||||
|
||||
# predict flow
|
||||
pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=False, drop_text=False, cache=True
|
||||
)
|
||||
# predict flow (cond)
|
||||
if cfg_strength < 1e-5:
|
||||
pred = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
drop_audio_cond=False,
|
||||
drop_text=False,
|
||||
cache=True,
|
||||
)
|
||||
return pred
|
||||
|
||||
null_pred = self.transformer(
|
||||
x=x, cond=step_cond, text=text, time=t, mask=mask, drop_audio_cond=True, drop_text=True, cache=True
|
||||
# predict flow (cond and uncond), for classifier-free guidance
|
||||
pred_cfg = self.transformer(
|
||||
x=x,
|
||||
cond=step_cond,
|
||||
text=text,
|
||||
time=t,
|
||||
mask=mask,
|
||||
cfg_infer=True,
|
||||
cache=True,
|
||||
)
|
||||
pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
|
||||
return pred + (pred - null_pred) * cfg_strength
|
||||
|
||||
# noise input
|
||||
@@ -190,7 +207,10 @@ class CFM(nn.Module):
|
||||
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||
steps = int(steps * (1 - t_start))
|
||||
|
||||
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
||||
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
|
||||
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
|
||||
else:
|
||||
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
|
||||
if sway_sampling_coef is not None:
|
||||
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||
|
||||
@@ -270,10 +290,9 @@ class CFM(nn.Module):
|
||||
else:
|
||||
drop_text = False
|
||||
|
||||
# if want rigorously mask out padding, record in collate_fn in dataset.py, and pass in here
|
||||
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
||||
# apply mask will use more memory; might adjust batchsize or batchsampler long sequence threshold
|
||||
pred = self.transformer(
|
||||
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text
|
||||
x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, mask=mask
|
||||
)
|
||||
|
||||
# flow matching loss
|
||||
|
||||
@@ -312,7 +312,7 @@ def collate_fn(batch):
|
||||
max_mel_length = mel_lengths.amax()
|
||||
|
||||
padded_mel_specs = []
|
||||
for spec in mel_specs: # TODO. maybe records mask for attention here
|
||||
for spec in mel_specs:
|
||||
padding = (0, max_mel_length - spec.size(-1))
|
||||
padded_spec = F.pad(spec, padding, value=0)
|
||||
padded_mel_specs.append(padded_spec)
|
||||
@@ -324,7 +324,7 @@ def collate_fn(batch):
|
||||
|
||||
return dict(
|
||||
mel=mel_specs,
|
||||
mel_lengths=mel_lengths,
|
||||
mel_lengths=mel_lengths, # records for padding mask
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
# flake8: noqa
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -19,6 +20,8 @@ from librosa.filters import mel as librosa_mel_fn
|
||||
from torch import nn
|
||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||
|
||||
from f5_tts.model.utils import is_package_available
|
||||
|
||||
|
||||
# raw wav to mel spec
|
||||
|
||||
@@ -175,7 +178,7 @@ class ConvPositionEmbedding(nn.Module):
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
@@ -417,9 +420,9 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b n d"] = None, # context c # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # noised input x
|
||||
c: float["b n d"] = None, # context c
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.Tensor:
|
||||
@@ -431,19 +434,30 @@ class Attention(nn.Module):
|
||||
|
||||
# Attention processor
|
||||
|
||||
if is_package_available("flash_attn"):
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn import flash_attn_varlen_func, flash_attn_func
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
pe_attn_head: int | None = None, # number of attention head to apply rope, None for all
|
||||
attn_backend: str = "torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled: bool = True,
|
||||
):
|
||||
if attn_backend == "flash_attn":
|
||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||
|
||||
self.pe_attn_head = pe_attn_head
|
||||
self.attn_backend = attn_backend
|
||||
self.attn_mask_enabled = attn_mask_enabled
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # noised input x
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = x.shape[0]
|
||||
@@ -479,16 +493,40 @@ class AttnProcessor:
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = mask
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
if self.attn_backend == "torch":
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
attn_mask = mask
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
elif self.attn_backend == "flash_attn":
|
||||
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
|
||||
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
|
||||
value, _, _, _, _ = unpad_input(value, mask)
|
||||
x = flash_attn_varlen_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_cu_seqlens,
|
||||
k_cu_seqlens,
|
||||
q_max_seqlen_in_batch,
|
||||
k_max_seqlen_in_batch,
|
||||
)
|
||||
x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
|
||||
x = x.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
else:
|
||||
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
|
||||
x = x.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
@@ -514,9 +552,9 @@ class JointAttnProcessor:
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
x: float["b n d"], # noised input x
|
||||
c: float["b nt d"] = None, # context c, here text
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.FloatTensor:
|
||||
@@ -608,12 +646,27 @@ class JointAttnProcessor:
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads,
|
||||
dim_head,
|
||||
ff_mult=4,
|
||||
dropout=0.1,
|
||||
qk_norm=None,
|
||||
pe_attn_head=None,
|
||||
attn_backend="torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(pe_attn_head=pe_attn_head),
|
||||
processor=AttnProcessor(
|
||||
pe_attn_head=pe_attn_head,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
@@ -724,7 +777,7 @@ class TimestepEmbedding(nn.Module):
|
||||
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, timestep: float["b"]): # noqa: F821
|
||||
def forward(self, timestep: float["b"]):
|
||||
time_hidden = self.time_embed(timestep)
|
||||
time_hidden = time_hidden.to(timestep.dtype)
|
||||
time = self.time_mlp(time_hidden) # b d
|
||||
|
||||
@@ -149,7 +149,7 @@ class Trainer:
|
||||
if self.is_main:
|
||||
checkpoint = dict(
|
||||
model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
|
||||
optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
||||
optimizer_state_dict=self.optimizer.state_dict(),
|
||||
ema_model_state_dict=self.ema_model.state_dict(),
|
||||
scheduler_state_dict=self.scheduler.state_dict(),
|
||||
update=update,
|
||||
@@ -242,7 +242,7 @@ class Trainer:
|
||||
del checkpoint["model_state_dict"][key]
|
||||
|
||||
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
|
||||
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
if self.scheduler:
|
||||
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
update = checkpoint["update"]
|
||||
|
||||
@@ -35,6 +35,16 @@ def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
|
||||
def is_package_available(package_name: str) -> bool:
|
||||
try:
|
||||
import importlib
|
||||
|
||||
package_exists = importlib.util.find_spec(package_name) is not None
|
||||
return package_exists
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# tensor helpers
|
||||
|
||||
|
||||
@@ -189,3 +199,22 @@ def repetition_found(text, length=2, tolerance=10):
|
||||
if count > tolerance:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# get the empirically pruned step for sampling
|
||||
|
||||
|
||||
def get_epss_timesteps(n, device, dtype):
|
||||
dt = 1 / 32
|
||||
predefined_timesteps = {
|
||||
5: [0, 2, 4, 8, 16, 32],
|
||||
6: [0, 2, 4, 6, 8, 16, 32],
|
||||
7: [0, 2, 4, 6, 8, 16, 24, 32],
|
||||
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
|
||||
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
|
||||
}
|
||||
t = predefined_timesteps.get(n, [])
|
||||
if not t:
|
||||
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
|
||||
return dt * torch.tensor(t, device=device, dtype=dtype)
|
||||
|
||||
@@ -220,8 +220,8 @@ def get_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
waveform = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
@@ -244,7 +244,7 @@ async def send(
|
||||
model_name: str,
|
||||
padding_duration: int = None,
|
||||
audio_save_dir: str = "./",
|
||||
save_sample_rate: int = 16000,
|
||||
save_sample_rate: int = 24000,
|
||||
):
|
||||
total_duration = 0.0
|
||||
latency_data = []
|
||||
@@ -254,7 +254,7 @@ async def send(
|
||||
for i, item in enumerate(manifest_item_list):
|
||||
if i % log_interval == 0:
|
||||
print(f"{name}: {i}/{len(manifest_item_list)}")
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
|
||||
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
|
||||
duration = len(waveform) / sample_rate
|
||||
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||
|
||||
@@ -310,8 +310,9 @@ async def send(
|
||||
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
|
||||
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
|
||||
|
||||
latency_data.append((end, estimated_target_duration))
|
||||
total_duration += estimated_target_duration
|
||||
actual_duration = len(audio) / save_sample_rate
|
||||
latency_data.append((end, actual_duration))
|
||||
total_duration += actual_duration
|
||||
|
||||
return total_duration, latency_data
|
||||
|
||||
@@ -416,7 +417,7 @@ async def main():
|
||||
model_name=args.model_name,
|
||||
audio_save_dir=args.log_dir,
|
||||
padding_duration=1,
|
||||
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
|
||||
save_sample_rate=24000,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
@@ -82,7 +82,7 @@ def prepare_request(
|
||||
samples,
|
||||
reference_text,
|
||||
target_text,
|
||||
sample_rate=16000,
|
||||
sample_rate=24000,
|
||||
audio_save_dir: str = "./",
|
||||
):
|
||||
assert len(samples.shape) == 1, "samples should be 1D"
|
||||
@@ -106,8 +106,8 @@ def prepare_request(
|
||||
return data
|
||||
|
||||
|
||||
def load_audio(wav_path, target_sample_rate=16000):
|
||||
assert target_sample_rate == 16000, "hard coding in server"
|
||||
def load_audio(wav_path, target_sample_rate=24000):
|
||||
assert target_sample_rate == 24000, "hard coding in server"
|
||||
if isinstance(wav_path, dict):
|
||||
samples = wav_path["array"]
|
||||
sample_rate = wav_path["sampling_rate"]
|
||||
@@ -129,7 +129,7 @@ if __name__ == "__main__":
|
||||
|
||||
url = f"{server_url}/v2/models/{args.model_name}/infer"
|
||||
samples, sr = load_audio(args.reference_audio)
|
||||
assert sr == 16000, "sample rate hardcoded in server"
|
||||
assert sr == 24000, "sample rate hardcoded in server"
|
||||
|
||||
samples = np.array(samples, dtype=np.float32)
|
||||
data = prepare_request(samples, args.reference_text, args.target_text)
|
||||
|
||||
@@ -33,7 +33,7 @@ parameters [
|
||||
},
|
||||
{
|
||||
key: "reference_audio_sample_rate",
|
||||
value: {string_value:"16000"}
|
||||
value: {string_value:"24000"}
|
||||
},
|
||||
{
|
||||
key: "vocoder",
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
# Training
|
||||
|
||||
Check your FFmpeg installation:
|
||||
```bash
|
||||
ffmpeg -version
|
||||
```
|
||||
If not found, install it first (or skip assuming you know of other backends available).
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
Example data processing scripts, and you may tailor your own one along with a Dataset class in `src/f5_tts/model/dataset.py`.
|
||||
|
||||
@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
||||
out_dir.mkdir(exist_ok=True, parents=True)
|
||||
print(f"\nSaving to {out_dir} ...")
|
||||
|
||||
# Save dataset with improved batch size for better I/O performance
|
||||
raw_arrow_path = out_dir / "raw.arrow"
|
||||
with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
|
||||
with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# Save durations to JSON
|
||||
dur_json_path = out_dir / "duration.json"
|
||||
|
||||
@@ -181,6 +181,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -68,6 +68,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||
|
||||
@@ -62,6 +62,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -39,6 +39,7 @@ def main():
|
||||
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
||||
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
||||
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -178,50 +178,12 @@ def get_audio_duration(audio_path):
|
||||
return audio.shape[1] / sample_rate
|
||||
|
||||
|
||||
def clear_text(text):
|
||||
"""Clean and prepare text by lowering the case and stripping whitespace."""
|
||||
return text.lower().strip()
|
||||
|
||||
|
||||
def get_rms(
|
||||
y,
|
||||
frame_length=2048,
|
||||
hop_length=512,
|
||||
pad_mode="constant",
|
||||
): # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
|
||||
padding = (int(frame_length // 2), int(frame_length // 2))
|
||||
y = np.pad(y, padding, mode=pad_mode)
|
||||
|
||||
axis = -1
|
||||
# put our new within-frame axis at the end for now
|
||||
out_strides = y.strides + tuple([y.strides[axis]])
|
||||
# Reduce the shape on the framing axis
|
||||
x_shape_trimmed = list(y.shape)
|
||||
x_shape_trimmed[axis] -= frame_length - 1
|
||||
out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
|
||||
xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
|
||||
if axis < 0:
|
||||
target_axis = axis - 1
|
||||
else:
|
||||
target_axis = axis + 1
|
||||
xw = np.moveaxis(xw, -1, target_axis)
|
||||
# Downsample along the target axis
|
||||
slices = [slice(None)] * xw.ndim
|
||||
slices[axis] = slice(0, None, hop_length)
|
||||
x = xw[tuple(slices)]
|
||||
|
||||
# Calculate power
|
||||
power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
|
||||
|
||||
return np.sqrt(power)
|
||||
|
||||
|
||||
class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
|
||||
def __init__(
|
||||
self,
|
||||
sr: int,
|
||||
threshold: float = -40.0,
|
||||
min_length: int = 2000,
|
||||
min_length: int = 20000, # 20 seconds
|
||||
min_interval: int = 300,
|
||||
hop_size: int = 20,
|
||||
max_sil_kept: int = 2000,
|
||||
@@ -252,7 +214,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
|
||||
samples = waveform
|
||||
if samples.shape[0] <= self.min_length:
|
||||
return [waveform]
|
||||
rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
||||
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
||||
sil_tags = []
|
||||
silence_start = None
|
||||
clip_start = 0
|
||||
@@ -306,8 +268,7 @@ class Slicer: # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.
|
||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
||||
sil_tags.append((pos, total_frames + 1))
|
||||
# Apply and return slices.
|
||||
####音频+起始时间+终止时间
|
||||
# Apply and return slices: [chunk, start, end]
|
||||
if len(sil_tags) == 0:
|
||||
return [[waveform, 0, int(total_frames * self.hop_size)]]
|
||||
else:
|
||||
@@ -434,7 +395,7 @@ def start_training(
|
||||
fp16 = ""
|
||||
|
||||
cmd = (
|
||||
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
|
||||
f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
|
||||
f" --learning_rate {learning_rate}"
|
||||
f" --batch_size_per_gpu {batch_size_per_gpu}"
|
||||
f" --batch_size_type {batch_size_type}"
|
||||
@@ -453,7 +414,7 @@ def start_training(
|
||||
cmd += " --finetune"
|
||||
|
||||
if file_checkpoint_train != "":
|
||||
cmd += f" --pretrain {file_checkpoint_train}"
|
||||
cmd += f' --pretrain "{file_checkpoint_train}"'
|
||||
|
||||
if tokenizer_file != "":
|
||||
cmd += f" --tokenizer_path {tokenizer_file}"
|
||||
@@ -707,7 +668,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.
|
||||
|
||||
try:
|
||||
text = transcribe(file_segment, language)
|
||||
text = text.lower().strip().replace('"', "")
|
||||
text = text.strip()
|
||||
|
||||
data += f"{name_segment}|{text}\n"
|
||||
|
||||
@@ -816,7 +777,7 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
||||
error_files.append([file_audio, "very short text length 3"])
|
||||
continue
|
||||
|
||||
text = clear_text(text)
|
||||
text = text.strip()
|
||||
text = convert_char_to_pinyin([text], polyphone=True)[0]
|
||||
|
||||
audio_path_list.append(file_audio)
|
||||
@@ -835,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
||||
min_second = round(min(duration_list), 2)
|
||||
max_second = round(max(duration_list), 2)
|
||||
|
||||
with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
|
||||
with ArrowWriter(path=file_raw) as writer:
|
||||
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
||||
writer.write(line)
|
||||
writer.finalize()
|
||||
|
||||
with open(file_duration, "w") as f:
|
||||
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
||||
@@ -1099,7 +1061,7 @@ def vocab_extend(project_name, symbols, model_type):
|
||||
return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
|
||||
|
||||
|
||||
def vocab_check(project_name):
|
||||
def vocab_check(project_name, tokenizer_type):
|
||||
name_project = project_name
|
||||
path_project = os.path.join(path_data, name_project)
|
||||
|
||||
@@ -1127,7 +1089,9 @@ def vocab_check(project_name):
|
||||
if len(sp) != 2:
|
||||
continue
|
||||
|
||||
text = sp[1].lower().strip()
|
||||
text = sp[1].strip()
|
||||
if tokenizer_type == "pinyin":
|
||||
text = convert_char_to_pinyin([text], polyphone=True)[0]
|
||||
|
||||
for t in text:
|
||||
if t not in vocab and t not in miss_symbols_keep:
|
||||
@@ -1232,8 +1196,8 @@ def infer(
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
||||
tts_api.infer(
|
||||
ref_file=ref_audio,
|
||||
ref_text=ref_text.lower().strip(),
|
||||
gen_text=gen_text.lower().strip(),
|
||||
ref_text=ref_text.strip(),
|
||||
gen_text=gen_text.strip(),
|
||||
nfe_step=nfe_step,
|
||||
speed=speed,
|
||||
remove_silence=remove_silence,
|
||||
@@ -1498,7 +1462,9 @@ Using the extended model, you can finetune to a new language that is missing sym
|
||||
txt_info_extend = gr.Textbox(label="Info", value="")
|
||||
|
||||
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
|
||||
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
|
||||
check_button.click(
|
||||
fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
|
||||
)
|
||||
extend_button.click(
|
||||
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user