1.0.0 F5-TTS v1 base model with better training and inference performance
src/f5_tts/infer/README.md

@@ -68,14 +68,16 @@ Basically you can inference with flags:
 ```bash
 # Leave --ref_text "" will have ASR model transcribe (extra GPU memory usage)
 f5-tts_infer-cli \
---model "F5-TTS" \
+--model F5TTS_v1_Base \
 --ref_audio "ref_audio.wav" \
 --ref_text "The content, subtitle or transcription of reference audio." \
 --gen_text "Some text you want TTS model generate for you."
 
-# Choose Vocoder
-f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
-f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
+# Use BigVGAN as vocoder. Currently only support F5TTS_Base.
+f5-tts_infer-cli --model F5TTS_Base --vocoder_name bigvgan --load_vocoder_from_local
+
+# Use custom path checkpoint, e.g.
+f5-tts_infer-cli --ckpt_file ckpts/F5TTS_Base/model_1200000.safetensors
 
 # More instructions
 f5-tts_infer-cli --help
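
For readers scripting around the renamed `--model` flag, a minimal sketch of driving the updated CLI from Python (flag values taken from the hunk above; the wrapper itself is illustrative and not part of this commit):

```python
# Illustrative wrapper around the v1 CLI invocation shown above.
import subprocess

subprocess.run(
    [
        "f5-tts_infer-cli",
        "--model", "F5TTS_v1_Base",
        "--ref_audio", "ref_audio.wav",
        "--ref_text", "The content, subtitle or transcription of reference audio.",
        "--gen_text", "Some text you want TTS model generate for you.",
    ],
    check=True,  # raise CalledProcessError if the CLI exits non-zero
)
```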
@@ -90,8 +92,8 @@ f5-tts_infer-cli -c custom.toml
 For example, you can use `.toml` to pass in variables, refer to `src/f5_tts/infer/examples/basic/basic.toml`:
 
 ```toml
-# F5-TTS | E2-TTS
-model = "F5-TTS"
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/basic/basic_ref_en.wav"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = "Some call me nature, others call me mother nature."
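
The CLI parses this file with `tomli` (`config = tomli.load(open(args.config, "rb"))` appears in a later hunk), so the same values can be inspected from Python; a small sketch whose fallbacks mirror the new defaults:

```python
# Sketch: parse the same config the CLI consumes.
import tomli

with open("src/f5_tts/infer/examples/basic/basic.toml", "rb") as f:
    config = tomli.load(f)

model = config.get("model", "F5TTS_v1_Base")  # new fallback per this commit
ref_text = config.get("ref_text", "")         # "" -> auto-transcribe the reference
print(model, repr(ref_text))
```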
@@ -105,8 +107,8 @@ output_dir = "tests"
 You can also leverage `.toml` file to do multi-style generation, refer to `src/f5_tts/infer/examples/multi/story.toml`.
 
 ```toml
-# F5-TTS | E2-TTS
-model = "F5-TTS"
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/multi/main.flac"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = ""
@@ -126,6 +128,22 @@ ref_text = ""
 ```
 You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voice, refer to `src/f5_tts/infer/examples/multi/story.txt`.
 
+## Socket Real-time Service
+
+Real-time voice output with chunk stream:
+
+```bash
+# Start socket server
+python src/f5_tts/socket_server.py
+
+# If PyAudio not installed
+sudo apt-get install portaudio19-dev
+pip install pyaudio
+
+# Communicate with socket client
+python src/f5_tts/socket_client.py
+```
+
 ## Speech Editing
 
 To test speech editing capabilities, use the following command:
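
To make the `[main]` `[town]` `[country]` convention concrete, an illustrative splitter (this is not how the repo parses `story.txt`, just a sketch of the marker format):

```python
# Illustrative only: split marked text into (voice, chunk) pairs.
import re

def split_voices(text: str, voices=("main", "town", "country")):
    parts = re.split(r"\[(" + "|".join(voices) + r")\]", text)
    # re.split keeps each captured voice name just before its chunk
    return [(parts[i], parts[i + 1].strip()) for i in range(1, len(parts), 2)]

print(split_voices("[main] A traveller arrived. [town] Welcome! [main] He smiled."))
# [('main', 'A traveller arrived.'), ('town', 'Welcome!'), ('main', 'He smiled.')]
```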
@@ -134,86 +152,3 @@ To test speech editing capabilities, use the following command:
 python src/f5_tts/infer/speech_edit.py
 ```
-
-## Socket Realtime Client
-
-To communicate with socket server you need to run
-```bash
-python src/f5_tts/socket_server.py
-```
-
-<details>
-<summary>Then create client to communicate</summary>
-
-```bash
-# If PyAudio not installed
-sudo apt-get install portaudio19-dev
-pip install pyaudio
-```
-
-``` python
-# Create the socket_client.py
-import socket
-import asyncio
-import pyaudio
-import numpy as np
-import logging
-import time
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-async def listen_to_F5TTS(text, server_ip="localhost", server_port=9998):
-    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    await asyncio.get_event_loop().run_in_executor(None, client_socket.connect, (server_ip, int(server_port)))
-
-    start_time = time.time()
-    first_chunk_time = None
-
-    async def play_audio_stream():
-        nonlocal first_chunk_time
-        p = pyaudio.PyAudio()
-        stream = p.open(format=pyaudio.paFloat32, channels=1, rate=24000, output=True, frames_per_buffer=2048)
-
-        try:
-            while True:
-                data = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 8192)
-                if not data:
-                    break
-                if data == b"END":
-                    logger.info("End of audio received.")
-                    break
-
-                audio_array = np.frombuffer(data, dtype=np.float32)
-                stream.write(audio_array.tobytes())
-
-                if first_chunk_time is None:
-                    first_chunk_time = time.time()
-
-        finally:
-            stream.stop_stream()
-            stream.close()
-            p.terminate()
-
-        logger.info(f"Total time taken: {time.time() - start_time:.4f} seconds")
-
-    try:
-        data_to_send = f"{text}".encode("utf-8")
-        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, data_to_send)
-        await play_audio_stream()
-
-    except Exception as e:
-        logger.error(f"Error in listen_to_F5TTS: {e}")
-
-    finally:
-        client_socket.close()
-
-
-if __name__ == "__main__":
-    text_to_send = "As a Reader assistant, I'm familiar with new technology. which are key to its improved performance in terms of both training speed and inference efficiency. Let's break down the components"
-
-    asyncio.run(listen_to_F5TTS(text_to_send))
-```
-
-</details>
-
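
The removed inline client above still documents the wire protocol that the standalone `socket_client.py` presumably keeps: send UTF-8 text, read raw float32 PCM chunks at 24 kHz until a `b"END"` sentinel. A condensed, non-streaming sketch (host/port defaults taken from the removed code; the helper name is invented):

```python
# Condensed client sketch: one request, collect every chunk, no live playback.
import socket

import numpy as np

def tts_over_socket(text: str, host: str = "localhost", port: int = 9998) -> np.ndarray:
    with socket.create_connection((host, port)) as sock:
        sock.sendall(text.encode("utf-8"))
        chunks = []
        while True:
            data = sock.recv(8192)
            if not data or data == b"END":  # server signals completion with END
                break
            chunks.append(np.frombuffer(data, dtype=np.float32))
    return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.float32)
```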
src/f5_tts/infer/SHARED.md

@@ -16,7 +16,7 @@
 <!-- omit in toc -->
 ### Supported Languages
 - [Multilingual](#multilingual)
-  - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
+  - [F5-TTS v1 v0 Base @ zh \& en @ F5-TTS](#f5-tts-v1-v0-base--zh--en--f5-tts)
 - [English](#english)
 - [Finnish](#finnish)
   - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
@@ -37,7 +37,17 @@
 
 ## Multilingual
 
-#### F5-TTS Base @ zh & en @ F5-TTS
+#### F5-TTS v1 v0 Base @ zh & en @ F5-TTS
 |Model|🤗Hugging Face|Data (Hours)|Model License|
 |:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS v1 Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_v1_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
+
+```bash
+Model: hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors
+Vocab: hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+```
+
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
 |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
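
A sketch of pulling the listed v1 checkpoint locally with `cached_path`, the same helper the updated `infer_cli.py` uses for `hf://` paths (paths copied from the listing above):

```python
# Sketch: resolve the hf:// paths above to local files.
from cached_path import cached_path

ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
vocab_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt"))
print(ckpt_file, vocab_file)
```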
@@ -45,7 +55,7 @@
 ```bash
 Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
 Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -64,7 +74,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
 ```bash
 Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
 Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
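
This and the per-language hunks below all make the same change: community checkpoints trained before v1 must now opt out of the new defaults via `text_mask_padding: False` and `pe_attn_head: 1`. A sketch of loading one with its updated config dict; the `load_model`/`DiT` usage mirrors `infer_gradio.py` in this commit, but treat the exact signature as an assumption:

```python
# Sketch: load the Finnish community checkpoint with its updated config.
from cached_path import cached_path
from f5_tts.infer.utils_infer import load_model
from f5_tts.model import DiT

cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512,
           text_mask_padding=False, conv_layers=4, pe_attn_head=1)
ckpt = str(cached_path("hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors"))
vocab = str(cached_path("hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt"))
ema_model = load_model(DiT, cfg, ckpt, vocab_file=vocab)
```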
@@ -78,7 +88,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
 ```bash
 Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
 Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -96,7 +106,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
 ```bash
 Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
 Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
-Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 - Authors: SPRING Lab, Indian Institute of Technology, Madras
@@ -113,7 +123,7 @@ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "c
 ```bash
 Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
 Vocab: hf://alien79/F5-TTS-italian/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
 - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -131,7 +141,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
 ```bash
 Model: hf://Jmica/F5TTS/JA_25498980/model_25498980.pt
 Vocab: hf://Jmica/F5TTS/JA_25498980/vocab_updated.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 
@@ -148,7 +158,7 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
 ```bash
 Model: hf://hotstone228/F5-TTS-Russian/model_last.safetensors
 Vocab: hf://hotstone228/F5-TTS-Russian/vocab.txt
-Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
+Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
 ```
 - Finetuned by [HotDro4illa](https://github.com/HotDro4illa)
 - Any improvements are welcome
src/f5_tts/infer/examples/basic/basic.toml

@@ -1,5 +1,5 @@
-# F5-TTS | E2-TTS
-model = "F5-TTS"
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/basic/basic_ref_en.wav"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = "Some call me nature, others call me mother nature."
src/f5_tts/infer/examples/multi/story.toml

@@ -1,5 +1,5 @@
-# F5-TTS | E2-TTS
-model = "F5-TTS"
+# F5TTS_v1_Base | E2TTS_Base
+model = "F5TTS_v1_Base"
 ref_audio = "infer/examples/multi/main.flac"
 # If an empty "", transcribes the reference audio automatically.
 ref_text = ""
src/f5_tts/infer/infer_cli.py

@@ -27,7 +27,7 @@ from f5_tts.infer.utils_infer import (
     preprocess_ref_audio_text,
     remove_silence_for_generated_wav,
 )
-from f5_tts.model import DiT, UNetT
+from f5_tts.model import DiT, UNetT  # noqa: F401. used for config
 
 
 parser = argparse.ArgumentParser(
@@ -50,7 +50,7 @@ parser.add_argument(
     "-m",
     "--model",
     type=str,
-    help="The model name: F5-TTS | E2-TTS",
+    help="The model name: F5TTS_v1_Base | F5TTS_Base | E2TTS_Base | etc.",
 )
 parser.add_argument(
     "-mc",
@@ -172,8 +172,7 @@ config = tomli.load(open(args.config, "rb"))
 
 # command-line interface parameters
 
-model = args.model or config.get("model", "F5-TTS")
-model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
+model = args.model or config.get("model", "F5TTS_v1_Base")
 ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
 vocab_file = args.vocab_file or config.get("vocab_file", "")
 
@@ -245,36 +244,32 @@ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_loc
 
 # load TTS model
 
-if model == "F5-TTS":
-    model_cls = DiT
-    model_cfg = OmegaConf.load(model_cfg).model.arch
-    if not ckpt_file:  # path not specified, download from repo
-        if vocoder_name == "vocos":
-            repo_name = "F5-TTS"
-            exp_name = "F5TTS_Base"
-            ckpt_step = 1200000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-            # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
-        elif vocoder_name == "bigvgan":
-            repo_name = "F5-TTS"
-            exp_name = "F5TTS_Base_bigvgan"
-            ckpt_step = 1250000
-            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
-elif model == "E2-TTS":
-    assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
-    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-    if not ckpt_file:  # path not specified, download from repo
-        repo_name = "E2-TTS"
-        exp_name = "E2TTS_Base"
-        ckpt_step = 1200000
-        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-        # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
+model_cfg = OmegaConf.load(
+    args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
+).model
+model_cls = globals()[model_cfg.backbone]
+
+repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"
+
+if model != "F5TTS_Base":
+    assert vocoder_name == model_cfg.mel_spec.mel_spec_type
+
+# override for previous models
+if model == "F5TTS_Base":
+    if vocoder_name == "vocos":
+        ckpt_step = 1200000
+    elif vocoder_name == "bigvgan":
+        model = "F5TTS_Base_bigvgan"
+        ckpt_type = "pt"
+elif model == "E2TTS_Base":
+    repo_name = "E2-TTS"
+    ckpt_step = 1200000
+
+if not ckpt_file:
+    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}"))
 
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+ema_model = load_model(model_cls, model_cfg.arch, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
 
 
 # inference process
 
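
The rewrite drops the per-model if/elif ladder for a config-driven lookup: the YAML names the backbone class and `globals()` resolves it. A minimal sketch of the pattern, assuming the packaged `configs/F5TTS_v1_Base.yaml` carries the `model.backbone` and `model.arch` fields the new code reads:

```python
# Sketch of the new dispatch; field names taken from the hunk above.
from importlib.resources import files

from omegaconf import OmegaConf

from f5_tts.model import DiT, UNetT  # noqa: F401. looked up by name below

model_cfg = OmegaConf.load(str(files("f5_tts").joinpath("configs/F5TTS_v1_Base.yaml"))).model
model_cls = globals()[model_cfg.backbone]  # e.g. "DiT" -> the DiT class
print(model_cls, OmegaConf.to_container(model_cfg.arch))
```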
src/f5_tts/infer/infer_gradio.py

@@ -41,12 +41,12 @@ from f5_tts.infer.utils_infer import (
 )
 
 
-DEFAULT_TTS_MODEL = "F5-TTS"
+DEFAULT_TTS_MODEL = "F5-TTS_v1"
 tts_model_choice = DEFAULT_TTS_MODEL
 
 DEFAULT_TTS_MODEL_CFG = [
-    "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
-    "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
+    "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
+    "hf://SWivid/F5-TTS/F5TTS_v1_Base/vocab.txt",
     json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
 ]
@@ -56,13 +56,15 @@ DEFAULT_TTS_MODEL_CFG = [
 vocoder = load_vocoder()
 
 
-def load_f5tts(ckpt_path=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))):
-    F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+def load_f5tts():
+    ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
+    F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
     return load_model(DiT, F5TTS_model_cfg, ckpt_path)
 
 
-def load_e2tts(ckpt_path=str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))):
-    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+def load_e2tts():
+    ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+    E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
     return load_model(UNetT, E2TTS_model_cfg, ckpt_path)
 
 
@@ -73,7 +75,7 @@ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
     if vocab_path.startswith("hf://"):
         vocab_path = str(cached_path(vocab_path))
     if model_cfg is None:
-        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+        model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
     return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
 
 
@@ -130,7 +132,7 @@ def infer(
 
     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
 
-    if model == "F5-TTS":
+    if model == DEFAULT_TTS_MODEL:
         ema_model = F5TTS_ema_model
     elif model == "E2-TTS":
         global E2TTS_ema_model
@@ -762,7 +764,7 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
 """
     )
 
-    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
+    last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
 
     def load_last_used_custom():
         try:
@@ -821,7 +823,30 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
             custom_model_cfg = gr.Dropdown(
                 choices=[
                     DEFAULT_TTS_MODEL_CFG[2],
-                    json.dumps(dict(dim=768, depth=18, heads=12, ff_mult=2, text_dim=512, conv_layers=4)),
+                    json.dumps(
+                        dict(
+                            dim=1024,
+                            depth=22,
+                            heads=16,
+                            ff_mult=2,
+                            text_dim=512,
+                            text_mask_padding=False,
+                            conv_layers=4,
+                            pe_attn_head=1,
+                        )
+                    ),
+                    json.dumps(
+                        dict(
+                            dim=768,
+                            depth=18,
+                            heads=12,
+                            ff_mult=2,
+                            text_dim=512,
+                            text_mask_padding=False,
+                            conv_layers=4,
+                            pe_attn_head=1,
+                        )
+                    ),
                 ],
                 value=load_last_used_custom()[2],
                 allow_custom_value=True,
src/f5_tts/infer/speech_edit.py

@@ -2,12 +2,15 @@ import os
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 
+from importlib.resources import files
+
 import torch
 import torch.nn.functional as F
 import torchaudio
+from omegaconf import OmegaConf
 
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
-from f5_tts.model import CFM, DiT, UNetT
+from f5_tts.model import CFM, DiT, UNetT  # noqa: F401. used for config
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
 device = (
@@ -21,44 +24,40 @@ device = (
 )
 
 
-# --------------------- Dataset Settings -------------------- #
-
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
-mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
-target_rms = 0.1
-
-tokenizer = "pinyin"
-dataset_name = "Emilia_ZH_EN"
-
-
 # ---------------------- infer setting ---------------------- #
 
 seed = None  # int | None
 
-exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
-ckpt_step = 1200000
+exp_name = "F5TTS_v1_Base"  # F5TTS_v1_Base | E2TTS_Base
+ckpt_step = 1250000
 
 nfe_step = 32  # 16, 32
 cfg_strength = 2.0
 ode_method = "euler"  # euler | midpoint
 sway_sampling_coef = -1.0
 speed = 1.0
+target_rms = 0.1
 
-if exp_name == "F5TTS_Base":
-    model_cls = DiT
-    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-
-elif exp_name == "E2TTS_Base":
-    model_cls = UNetT
-    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{exp_name}.yaml")))
+model_cls = globals()[model_cfg.model.backbone]
+model_arc = model_cfg.model.arch
 
-ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
+dataset_name = model_cfg.datasets.name
+tokenizer = model_cfg.model.tokenizer
+
+mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
+target_sample_rate = model_cfg.model.mel_spec.target_sample_rate
+n_mel_channels = model_cfg.model.mel_spec.n_mel_channels
+hop_length = model_cfg.model.mel_spec.hop_length
+win_length = model_cfg.model.mel_spec.win_length
+n_fft = model_cfg.model.mel_spec.n_fft
+
+
+ckpt_path = str(files("f5_tts").joinpath("../../")) + f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
 output_dir = "tests"
 
 
 # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
 # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
 # [write the origin_text into a file, e.g. tests/test_edit.txt]
@@ -67,7 +66,7 @@ output_dir = "tests"
 # [--language "zho" for Chinese, "eng" for English]
 # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
 
-audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
+audio_to_edit = str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav"))
 origin_text = "Some call me nature, others call me mother nature."
 target_text = "Some call me optimist, others call me realist."
 parts_to_edit = [
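
The hunk cuts off at `parts_to_edit = [`, so the span values themselves are not part of this diff. Purely for orientation, a hypothetical configuration of these knobs (timestamps invented, not from this commit):

```python
# Hypothetical values: re-synthesize two word-level spans of the reference clip.
origin_text = "Some call me nature, others call me mother nature."
target_text = "Some call me optimist, others call me realist."
parts_to_edit = [
    [1.4, 2.4],  # hypothetical start/end seconds covering the first edited phrase
    [4.0, 4.9],  # hypothetical span covering the second edited phrase
]
```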
@@ -106,7 +105,7 @@ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
 
 # Model
 model = CFM(
-    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+    transformer=model_cls(**model_arc, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
     mel_spec_kwargs=dict(
         n_fft=n_fft,
         hop_length=hop_length,
|
||||
@@ -301,19 +301,19 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 15s, clipping short. (1)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
|
||||
# 2. try to find short silence for clipping if 1. failed
|
||||
if len(non_silent_wave) > 15000:
|
||||
if len(non_silent_wave) > 12000:
|
||||
non_silent_segs = silence.split_on_silence(
|
||||
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
|
||||
)
|
||||
non_silent_wave = AudioSegment.silent(duration=0)
|
||||
for non_silent_seg in non_silent_segs:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
|
||||
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
|
||||
show_info("Audio is over 15s, clipping short. (2)")
|
||||
break
|
||||
non_silent_wave += non_silent_seg
|
||||
@@ -321,8 +321,8 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_in
         aseg = non_silent_wave
 
         # 3. if no proper silence found for clipping
-        if len(aseg) > 15000:
-            aseg = aseg[:15000]
+        if len(aseg) > 12000:
+            aseg = aseg[:12000]
             show_info("Audio is over 15s, clipping short. (3)")
 
     aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
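
Steps (1)-(3) together tighten the reference-clip budget from 15 s to 12 s while keeping the two-pass silence search. A condensed sketch of the strategy (pydub calls and the 6000/12000 ms thresholds as in the new code above; the helper name is invented):

```python
# Condensed sketch: accumulate non-silent chunks until the budget is hit.
from pydub import AudioSegment, silence

def clip_to_budget(aseg: AudioSegment, budget_ms: int = 12000) -> AudioSegment:
    segs = silence.split_on_silence(aseg, min_silence_len=100, silence_thresh=-40,
                                    keep_silence=1000, seek_step=10)
    out = AudioSegment.silent(duration=0)
    for seg in segs:
        # Stop once we have at least 6 s and the next chunk would exceed the budget.
        if len(out) > 6000 and len(out + seg) > budget_ms:
            break
        out += seg
    return out[:budget_ms]  # hard cut if no usable silence was found
```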
@@ -383,7 +383,7 @@ def infer_process(
 ):
     # Split the input text into batches
     audio, sr = torchaudio.load(ref_audio)
-    max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
+    max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
     for i, gen_text in enumerate(gen_text_batches):
         print(f"gen_text {i}", gen_text)
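
The constant change from 25 to 22 tightens the per-batch character budget derived from the reference clip. A quick worked example (numbers illustrative):

```python
# Illustrative numbers: an 8 s reference whose transcript is 100 UTF-8 bytes.
ref_bytes, ref_sec = 100, 8.0
old_budget = int(ref_bytes / ref_sec * (25 - ref_sec))  # 212 chars per batch before
new_budget = int(ref_bytes / ref_sec * (22 - ref_sec))  # 175 chars per batch after
print(old_budget, new_budget)
```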