19 Commits
1.1.4 ... 1.1.5

Author SHA1 Message Date
SWivid
646f34b20f v1.1.5 pypi 2025-06-06 07:08:59 +08:00
Jerrister Zheng
2e2acc6ea2 Update: Empirically Pruned Step Sampling (#1077)
* update Empirically Pruned Step Sampling

---------

Co-authored-by: Fast-F5-TTS <2942755472@qq.com>
Co-authored-by: SWivid <swivid@qq.com>
2025-06-04 22:59:30 +08:00
SWivid
6fbe7592f5 rebase default sample_rate to 24khz for runtime 2025-06-04 11:22:31 +08:00
Alice Yanagi
7e37bc5d9a Fix the duration computation in triton_trtllm/client_grpc.py (#1071)
* Update client_grpc.py

Using `actual_duration` to compute metrics like RTF.
2025-06-04 11:18:00 +08:00
SWivid
35f130ee85 minor update for infer-gradio 2025-06-04 06:11:49 +08:00
SWivid
e6469f705f update shared.md 2025-06-03 22:09:13 +08:00
SWivid
31cd818095 formatting 2025-06-03 21:23:47 +08:00
Yushen CHEN
1d13664b24 Merge pull request #1063 from ionite34/dev
Fix finetune training with spaces in file paths
2025-06-03 21:18:41 +08:00
Yushen CHEN
b27471ea06 Merge pull request #1072 from hvoss-techfak/main
German Model support
2025-06-03 21:18:25 +08:00
Hendric Voss
8fb55f107e Update SHARED.md 2025-06-03 14:08:30 +02:00
Hendric Voss
ccb380b752 Added German Model 2025-06-03 14:08:03 +02:00
Ionite
3027b43953 Fix training with file path spaces 2025-05-28 15:24:35 -04:00
SWivid
ecd1c3949a Add py312 check for tempfile delete_on_close keyword 2025-05-22 23:10:29 +08:00
SWivid
2968aa184f v1.1.5 several fixes 2025-05-22 17:41:10 +08:00
SWivid
fb26b6d93e Fix #1046 tempfile related bug 2025-05-22 17:40:14 +08:00
SWivid
f7f266cdd9 preprocess only once. Fix #1043 2025-05-21 02:26:05 +08:00
SWivid
695c735737 Exclude broken dependency version with accelerate 2025-05-16 17:48:41 +08:00
SWivid
3e2a07da1d Update README.md & minor fixes 2025-05-11 19:40:37 +08:00
SWivid
c47687487c minor fix for vocab check in finetune_gradio 2025-05-05 23:32:00 +08:00
14 changed files with 173 additions and 69 deletions

View File

@@ -91,7 +91,7 @@ conda activate f5-tts
> ```bash
> git clone https://github.com/SWivid/F5-TTS.git
> cd F5-TTS
> # git submodule update --init --recursive # (optional, if need bigvgan)
> # git submodule update --init --recursive # (optional, if use bigvgan as vocoder)
> pip install -e .
> ```

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.4"
version = "1.1.5"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -14,7 +14,7 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"accelerate>=0.33.0",
"accelerate>=0.33.0,!=1.7.0",
"bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
"cached_path",
"click",

View File

@@ -13,7 +13,7 @@ To avoid possible inference failures, make sure you have seen through the follow
- Add some spaces (blank: " ") or punctuation marks (e.g. "," ".") <ins>to explicitly introduce some pauses</ins>.
- If an English punctuation mark ends a sentence, make sure there is a space " " after it; otherwise it will not be treated as a split point when chunking.
- <ins>Preprocess numbers</ins> to Chinese characters if you want them read in Chinese, otherwise they will be read in English.
- If the generation output is blank (pure silence), <ins>check for ffmpeg installation</ins>.
- If the generation output is blank (pure silence), <ins>check for FFmpeg installation</ins>.
- Try <ins>turning off `use_ema` if using an early-stage</ins> finetuned checkpoint (one trained for only a few updates).
@@ -129,6 +129,28 @@ ref_text = ""
```
You should mark the voice with `[main]` `[town]` `[country]` whenever you want to change voices; refer to `src/f5_tts/infer/examples/multi/story.txt`.
## API Usage
```python
from importlib.resources import files
from f5_tts.api import F5TTS
f5tts = F5TTS()
wav, sr, spec = f5tts.infer(
    ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
    ref_text="some call me nature, others call me mother nature.",
    gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
    file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
    file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
    seed=None,
)
```
Check [api.py](../api.py) for more details.
## TensorRT-LLM Deployment
See [detailed instructions](../runtime/triton_trtllm/README.md) for more information.
## Socket Real-time Service
Real-time voice output with chunk stream:

View File

@@ -22,6 +22,8 @@
- [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
- [French](#french)
- [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
- [German](#german)
- [F5-TTS Base @ de @ hvoss-techfak](#f5-tts-base--de--hvoss-techfak)
- [Hindi](#hindi)
- [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
- [Italian](#italian)
@@ -97,6 +99,22 @@ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "
- [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
## German
#### F5-TTS Base @ de @ hvoss-techfak
|Model|🤗Hugging Face|Data (Hours)|Model License|
|:---:|:------------:|:-----------:|:-------------:|
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/hvoss-techfak/F5-TTS-German)|[Mozilla Common Voice 19.0](https://commonvoice.mozilla.org/en/datasets) & 800 hours Crowdsourced |cc-by-nc-4.0|
```bash
Model: hf://hvoss-techfak/F5-TTS-German/model_f5tts_german.pt
Vocab: hf://hvoss-techfak/F5-TTS-German/vocab.txt
Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "text_mask_padding": False, "conv_layers": 4, "pe_attn_head": 1}
```
- Finetuned by [@hvoss-techfak](https://github.com/hvoss-techfak)
## Hindi
#### F5-TTS Small @ hi @ SPRINGLab

View File

@@ -323,7 +323,7 @@ def main():
ref_text_ = voices[voice]["ref_text"]
gen_text_ = text.strip()
print(f"Voice: {voice}")
audio_segment, final_sample_rate, spectragram = infer_process(
audio_segment, final_sample_rate, spectrogram = infer_process(
ref_audio_,
ref_text_,
gen_text_,

View File

@@ -3,6 +3,7 @@
import gc
import json
import os
import re
import tempfile
from collections import OrderedDict
@@ -41,6 +42,7 @@ from f5_tts.infer.utils_infer import (
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
tempfile_kwargs,
)
from f5_tts.model import DiT, UNetT
@@ -126,7 +128,7 @@ def load_text_from_file(file):
return gr.update(value=text)
@lru_cache(maxsize=100) # NOTE. need to ensure params of infer() hashable
@lru_cache(maxsize=1000) # NOTE. need to ensure params of infer() hashable
@gpu_decorator
def infer(
ref_audio_orig,
@@ -189,28 +191,24 @@ def infer(
# Remove silence
if remove_silence:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
sf.write(f.name, final_wave, final_sample_rate)
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
try:
sf.write(temp_path, final_wave, final_sample_rate)
remove_silence_for_generated_wav(f.name)
final_wave, _ = torchaudio.load(f.name)
finally:
os.unlink(temp_path)
final_wave = final_wave.squeeze().cpu().numpy()
# Save the spectrogram
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)
save_spectrogram(combined_spectrogram, spectrogram_path)
return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app_tts:
gr.Markdown("# Batched TTS")
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
@@ -314,6 +312,12 @@ with gr.Blocks() as app_tts:
outputs=[ref_text_input],
)
ref_audio_input.clear(
lambda: [None, None],
None,
[ref_text_input, ref_text_file],
)
generate_btn.click(
basic_tts,
inputs=[
@@ -926,6 +930,16 @@ Have a conversation with an AI using your reference voice!
)
with gr.Blocks() as app_credits:
gr.Markdown("""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
""")
with gr.Blocks() as app:
gr.Markdown(
f"""

View File

@@ -33,6 +33,7 @@ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
_ref_audio_cache = {}
_ref_text_cache = {}
device = (
"cuda"
@@ -44,6 +45,8 @@ device = (
else "cpu"
)
tempfile_kwargs = {"delete_on_close": False} if sys.version_info >= (3, 12) else {"delete": False}
# -----------------------------------------
target_sample_rate = 24000
@@ -290,62 +293,74 @@ def remove_silence_edges(audio, silence_threshold=-42):
# preprocess reference audio and text
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print):
def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
show_info("Converting audio...")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
# Compute a hash of the reference audio file
with open(ref_audio_orig, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
show_info("Using cached preprocessed reference audio...")
ref_audio = _ref_audio_cache[audio_hash]
else: # first pass, do preprocess
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
temp_path = f.name
aseg = AudioSegment.from_file(ref_audio_orig)
if clip_short:
# 1. try to find long silence for clipping
# 1. try to find long silence for clipping
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (1)")
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
# 2. try to find short silence for clipping if 1. failed
if len(non_silent_wave) > 12000:
non_silent_segs = silence.split_on_silence(
aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
)
non_silent_wave = AudioSegment.silent(duration=0)
for non_silent_seg in non_silent_segs:
if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 12000:
show_info("Audio is over 12s, clipping short. (2)")
break
non_silent_wave += non_silent_seg
aseg = non_silent_wave
aseg = non_silent_wave
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
# 3. if no proper silence found for clipping
if len(aseg) > 12000:
aseg = aseg[:12000]
show_info("Audio is over 12s, clipping short. (3)")
aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
aseg.export(f.name, format="wav")
ref_audio = f.name
aseg.export(temp_path, format="wav")
ref_audio = temp_path
# Compute a hash of the reference audio file
with open(ref_audio, "rb") as audio_file:
audio_data = audio_file.read()
audio_hash = hashlib.md5(audio_data).hexdigest()
# Cache the processed reference audio
_ref_audio_cache[audio_hash] = ref_audio
if not ref_text.strip():
global _ref_audio_cache
if audio_hash in _ref_audio_cache:
global _ref_text_cache
if audio_hash in _ref_text_cache:
# Use cached asr transcription
show_info("Using cached reference text...")
ref_text = _ref_audio_cache[audio_hash]
ref_text = _ref_text_cache[audio_hash]
else:
show_info("No reference text provided, transcribing reference audio...")
ref_text = transcribe(ref_audio)
# Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
_ref_audio_cache[audio_hash] = ref_text
_ref_text_cache[audio_hash] = ref_text
else:
show_info("Using custom reference text...")
@@ -384,7 +399,7 @@ def infer_process(
):
# Split the input text into batches
audio, sr = torchaudio.load(ref_audio)
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr) * speed)
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
for i, gen_text in enumerate(gen_text_batches):
print(f"gen_text {i}", gen_text)

View File

@@ -22,6 +22,7 @@ from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import (
default,
exists,
get_epss_timesteps,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
@@ -92,6 +93,7 @@ class CFM(nn.Module):
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
use_epss=True,
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
@@ -190,7 +192,10 @@ class CFM(nn.Module):
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE
t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
else:
t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

View File

@@ -189,3 +189,22 @@ def repetition_found(text, length=2, tolerance=10):
if count > tolerance:
return True
return False
# get the empirically pruned step for sampling
def get_epss_timesteps(n, device, dtype):
dt = 1 / 32
predefined_timesteps = {
5: [0, 2, 4, 8, 16, 32],
6: [0, 2, 4, 6, 8, 16, 32],
7: [0, 2, 4, 6, 8, 16, 24, 32],
10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
}
t = predefined_timesteps.get(n, [])
if not t:
return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
return dt * torch.tensor(t, device=device, dtype=dtype)
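As a quick illustration of the schedule defined above (a sketch, using the 7-step entry from the table):
```python
import torch

# 7 sampling steps: timesteps are dense near t=0 and sparse near t=1.
t = (1 / 32) * torch.tensor([0, 2, 4, 6, 8, 16, 24, 32], dtype=torch.float32)
print(t)
# tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.5000, 0.7500, 1.0000])

# Compare with the uniform fallback used when no predefined entry exists:
print(torch.linspace(0, 1, 7 + 1))
# tensor([0.0000, 0.1429, 0.2857, 0.4286, 0.5714, 0.7143, 0.8571, 1.0000])
```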

View File

@@ -220,8 +220,8 @@ def get_args():
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -244,7 +244,7 @@ async def send(
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
save_sample_rate: int = 24000,
):
total_duration = 0.0
latency_data = []
@@ -254,7 +254,7 @@ async def send(
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=24000)
duration = len(waveform) / sample_rate
lengths = np.array([[len(waveform)]], dtype=np.int32)
@@ -310,8 +310,9 @@ async def send(
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, estimated_target_duration))
total_duration += estimated_target_duration
actual_duration = len(audio) / save_sample_rate
latency_data.append((end, actual_duration))
total_duration += actual_duration
return total_duration, latency_data
@@ -416,7 +417,7 @@ async def main():
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=24000 if args.model_name == "f5_tts" else 16000,
save_sample_rate=24000,
)
)
tasks.append(task)
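To illustrate what the duration fix above changes, a minimal sketch of the RTF computation with hypothetical numbers; `latency_data` holds (wall-clock latency, generated-audio duration) pairs, as in the client after the fix:
```python
# Hypothetical per-request measurements: (seconds of latency, seconds of generated audio).
latency_data = [(0.8, 4.0), (1.1, 6.5)]

total_latency = sum(latency for latency, _ in latency_data)
total_duration = sum(duration for _, duration in latency_data)

rtf = total_latency / total_duration
print(f"RTF: {rtf:.3f}")  # RTF: 0.181 -- below 1.0 means faster than real time
```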

View File

@@ -82,7 +82,7 @@ def prepare_request(
samples,
reference_text,
target_text,
sample_rate=16000,
sample_rate=24000,
audio_save_dir: str = "./",
):
assert len(samples.shape) == 1, "samples should be 1D"
@@ -106,8 +106,8 @@ def prepare_request(
return data
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
def load_audio(wav_path, target_sample_rate=24000):
assert target_sample_rate == 24000, "hard coding in server"
if isinstance(wav_path, dict):
samples = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
@@ -129,7 +129,7 @@ if __name__ == "__main__":
url = f"{server_url}/v2/models/{args.model_name}/infer"
samples, sr = load_audio(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
assert sr == 24000, "sample rate hardcoded in server"
samples = np.array(samples, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)

View File

@@ -33,7 +33,7 @@ parameters [
},
{
key: "reference_audio_sample_rate",
value: {string_value:"16000"}
value: {string_value:"24000"}
},
{
key: "vocoder",

View File

@@ -1,5 +1,11 @@
# Training
Check your FFmpeg installation:
```bash
ffmpeg -version
```
If it is not found, install it first (or skip this step if you know another audio backend is available).
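A programmatic equivalent of the check above (a sketch using only the standard library):
```python
import shutil

if shutil.which("ffmpeg") is None:
    print("ffmpeg not found -- install it, or make sure another audio backend is available")
else:
    print("ffmpeg found at", shutil.which("ffmpeg"))
```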
## Prepare Dataset
Example data processing scripts are provided; you may also tailor your own, along with a Dataset class in `src/f5_tts/model/dataset.py`.

View File

@@ -434,7 +434,7 @@ def start_training(
fp16 = ""
cmd = (
f"accelerate launch {fp16} {file_train} --exp_name {exp_name}"
f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
f" --learning_rate {learning_rate}"
f" --batch_size_per_gpu {batch_size_per_gpu}"
f" --batch_size_type {batch_size_type}"
@@ -453,7 +453,7 @@ def start_training(
cmd += " --finetune"
if file_checkpoint_train != "":
cmd += f" --pretrain {file_checkpoint_train}"
cmd += f' --pretrain "{file_checkpoint_train}"'
if tokenizer_file != "":
cmd += f" --tokenizer_path {tokenizer_file}"
@@ -1099,7 +1099,7 @@ def vocab_extend(project_name, symbols, model_type):
return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"
def vocab_check(project_name):
def vocab_check(project_name, tokenizer_type):
name_project = project_name
path_project = os.path.join(path_data, name_project)
@@ -1128,6 +1128,8 @@ def vocab_check(project_name):
continue
text = sp[1].lower().strip()
if tokenizer_type == "pinyin":
text = convert_char_to_pinyin([text], polyphone=True)[0]
for t in text:
if t not in vocab and t not in miss_symbols_keep:
@@ -1498,7 +1500,9 @@ Using the extended model, you can finetune to a new language that is missing sym
txt_info_extend = gr.Textbox(label="Info", value="")
txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
check_button.click(
fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
)
extend_button.click(
fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
)
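To show why the quoting fix above matters, a small sketch with a hypothetical path containing a space; `shlex.split` mimics how the shell would tokenize the command:
```python
import shlex

file_train = "/data/my project/src/f5_tts/train/finetune_cli.py"  # hypothetical path with a space

# Unquoted, the space splits the path into two separate arguments:
print(shlex.split(f"accelerate launch {file_train} --exp_name F5TTS_Base"))
# ['accelerate', 'launch', '/data/my', 'project/src/f5_tts/train/finetune_cli.py', '--exp_name', 'F5TTS_Base']

# Quoted, as in the fix, the path stays a single argument:
print(shlex.split(f'accelerate launch "{file_train}" --exp_name F5TTS_Base'))
# ['accelerate', 'launch', '/data/my project/src/f5_tts/train/finetune_cli.py', '--exp_name', 'F5TTS_Base']
```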