10 Commits

Author SHA1 Message Date
Yushen CHEN
27e20fcf39 Merge pull request #1242 from acadarmeria/fix-speech-edit-mel-domain
Fix speech editing boundary artifacts by working in mel domain
2025-12-26 17:35:57 +08:00
acadarmeria
dff57ebd2a Fix speech editing boundary artifacts by working in mel domain
Previously, speech_edit.py worked in wav domain (inserting zeros into the waveform before computing mel spectrogram), which caused boundary artifacts when mel spectrogram windows straddled zeros and real audio.

This commit refactors the approach to work in mel domain:
- Compute mel spectrogram on the clean original audio first
- Insert zero frames in mel domain instead of zero samples in wav domain
- Use frame-level granularity throughout for consistency

Benefits:
- Eliminates boundary artifacts
- More consistent behavior regardless of small float variations in input times
- Cleaner edit boundaries

Changes to speech_edit.py (lines 148-220):
- Convert audio to mel using model.mel_spec() before editing
- Build mel_cond by concatenating original mel frames + zero frames
- Calculate all time-based values at frame level first, then convert to samples
- Pass mel_cond directly to model.sample() instead of raw audio
2025-12-26 08:49:57 +00:00
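The frame-level bookkeeping this commit describes boils down to converting edit times from seconds to mel-frame counts before touching any audio. A minimal sketch of that conversion, assuming hop_length = 256 and target_sample_rate = 24000 (the values implied by comments elsewhere in this diff; the actual script takes them from its config):

```python
# Illustrative sketch only: second -> mel-frame conversion used for the edit regions.
# Assumes hop_length = 256 and target_sample_rate = 24000 (see model config).
hop_length = 256
target_sample_rate = 24000

def sec_to_frames(t_sec: float) -> int:
    """Convert a time in seconds to a number of mel frames."""
    return round(t_sec * target_sample_rate / hop_length)

# Using the example edit region shown further down in the speech_edit.py diff:
# parts_to_edit = [[1.500, 2.784], ...], fix_duration = [1.284, ...]
start_frame = sec_to_frames(1.500)  # 141
end_frame = sec_to_frames(2.784)    # 261
edit_frames = sec_to_frames(1.284)  # 120 zero frames spliced into the mel
```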
SWivid
46ccc575c5 v1.1.15 workaround for gr.Accordion default open=False bug (#1239) 2025-12-21 15:06:44 +08:00
SWivid
39617fcf7a v1.1.12 bump gradio from 5.0 to 6.0, several fixes to ensure compatibility with new gradio version 2025-12-20 18:44:43 +08:00
Yushen Chen
5b82f97c26 fix #1239, use gradio>=6.0; add clearer instructions for ffmpeg installation (#1234) 2025-12-20 16:08:13 +08:00
SWivid
9ae46c8360 Replace jieba pkg with rjieba - a jieba-rs Python binding 2025-11-28 13:08:07 +00:00
SWivid
3eecd94baa restore avg upsampling support for batch, cover the non-mask case 2025-11-09 11:56:03 +00:00
SWivid
d9a69452ce formatting 2025-11-09 18:25:30 +08:00
Yushen CHEN
bc15df2b57 Merge pull request #1212 from QingyuLiu0521/fix/AverageUpsampling
Fix Average Upsampling conflict logic introduced by the previous batch inference fix.
2025-11-09 18:23:38 +08:00
QingyuLiu0521
9b2357a1b9 Fix Average Upsampling 2025-11-08 18:39:06 -05:00
10 changed files with 106 additions and 61 deletions

View File

@@ -30,6 +30,9 @@
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
conda create -n f5-tts python=3.11
conda activate f5-tts
# Install FFmpeg if you haven't yet
conda install ffmpeg
```
### Install PyTorch with matched device
@@ -39,7 +42,11 @@ conda activate f5-tts
> ```bash
> # Install pytorch with your CUDA version, e.g.
> pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
>
> # Previous versions are also possible, e.g.
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
> # etc.
> ```
</details>
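Not part of the diff, but a quick way to verify the resulting environment is a short sanity check that torch and torchaudio import and that CUDA is visible when a GPU build was installed:

```python
# Quick post-install sanity check (illustrative, not part of the repo).
import torch
import torchaudio

print("torch", torch.__version__, "| torchaudio", torchaudio.__version__)
print("CUDA available:", torch.cuda.is_available())
```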

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "f5-tts"
version = "1.1.9"
version = "1.1.15"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}
@@ -20,15 +20,15 @@ dependencies = [
"click",
"datasets",
"ema_pytorch>=0.5.2",
"gradio>=5.0.0",
"gradio>=6.0.0",
"hydra-core>=1.3.0",
"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4; python_version<='3.10'",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
"rjieba",
"safetensors",
"soundfile",
"tomli",

View File

@@ -221,7 +221,7 @@ with gr.Blocks() as app_tts:
)
gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
generate_btn = gr.Button("Synthesize", variant="primary")
with gr.Accordion("Advanced Settings", open=False):
with gr.Accordion("Advanced Settings", open=True) as adv_settn:
with gr.Row():
ref_text_input = gr.Textbox(
label="Reference Text",
@@ -269,6 +269,17 @@ with gr.Blocks() as app_tts:
info="Set the duration of the cross-fade between audio clips.",
)
def collapse_accordion():
return gr.Accordion(open=False)
# Workaround for https://github.com/SWivid/F5-TTS/issues/1239#issuecomment-3677987413
# i.e. set gr.Accordion(open=True) by default, then collapse it manually once the Blocks app has loaded
app_tts.load(
fn=collapse_accordion,
inputs=None,
outputs=adv_settn,
)
audio_output = gr.Audio(label="Synthesized Audio")
spectrogram_output = gr.Image(label="Spectrogram")
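Outside the diff, the same workaround can be reproduced in a self-contained app: declare the accordion open, then collapse it from a load event. A minimal sketch of the pattern above, not the repo's actual UI:

```python
import gradio as gr

def collapse_accordion():
    # Returning an Accordion update with open=False collapses it after render
    return gr.Accordion(open=False)

with gr.Blocks() as demo:
    with gr.Accordion("Advanced Settings", open=True) as adv_settn:
        gr.Textbox(label="Some advanced option")
    # Collapse once the Blocks app has loaded, mirroring the workaround above
    demo.load(fn=collapse_accordion, inputs=None, outputs=adv_settn)

if __name__ == "__main__":
    demo.launch()
```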
@@ -577,7 +588,7 @@ with gr.Blocks() as app_multistyle:
label="Cherry-pick Interface",
lines=10,
max_lines=40,
show_copy_button=True,
buttons=["copy"], # show_copy_button=True if gradio<6.0
interactive=False,
visible=False,
)
@@ -816,7 +827,9 @@ Have a conversation with an AI using your reference voice!
lines=2,
)
chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
chatbot_interface = gr.Chatbot(
label="Conversation"
) # type="messages" is hard-coded since gradio 6.0, no need to pass it in
with gr.Row():
with gr.Column():
@@ -853,6 +866,10 @@ Have a conversation with an AI using your reference voice!
@gpu_decorator
def generate_text_response(conv_state, system_prompt):
"""Generate text response from AI"""
for single_state in conv_state:
if isinstance(single_state["content"], list):
assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text"
single_state["content"] = single_state["content"][0]["text"]
system_prompt_state = [{"role": "system", "content": system_prompt}]
response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
@@ -866,7 +883,7 @@ Have a conversation with an AI using your reference voice!
if not conv_state or not ref_audio:
return None, ref_text, seed_input
last_ai_response = conv_state[-1]["content"]
last_ai_response = conv_state[-1]["content"][0]["text"]
if not last_ai_response or conv_state[-1]["role"] != "assistant":
return None, ref_text, seed_input
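As the normalization above suggests, gradio 6 chatbot entries can carry their text content as a list of typed parts, which is what the `[0]["text"]` indexing accounts for. A small illustration with a hypothetical message:

```python
# Illustrative gradio-6-style chatbot message and the normalization applied above.
message = {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]}
if isinstance(message["content"], list):
    assert len(message["content"]) == 1 and message["content"][0]["type"] == "text"
    message["content"] = message["content"][0]["text"]
print(message)  # {'role': 'assistant', 'content': 'Hello!'}
```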
@@ -1108,7 +1125,6 @@ def main(port, host, share, api, root_path, inbrowser):
server_name=host,
server_port=port,
share=share,
show_api=api,
root_path=root_path,
inbrowser=inbrowser,
)

View File

@@ -89,6 +89,12 @@ fix_duration = [
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
# fix_duration = None # use origin text duration
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
# target_text = "对,这就是你,万人敬仰的李白金星。"
# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
# fix_duration = [1.284, 2.677]
# -------------------------------------------------#
@@ -138,28 +144,55 @@ if rms < target_rms:
if sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
audio = resampler(audio)
offset = 0
audio_ = torch.zeros(1, 0)
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
# Convert to mel spectrogram FIRST (on clean original audio)
# This avoids boundary artifacts from mel windows straddling zeros and real audio
audio = audio.to(device)
with torch.inference_mode():
original_mel = model.mel_spec(audio) # (batch, n_mel, n_frames)
original_mel = original_mel.permute(0, 2, 1) # (batch, n_frames, n_mel)
# Build mel_cond and edit_mask at FRAME level
# Insert zero frames in mel domain instead of zero samples in wav domain
offset_frame = 0
mel_cond = torch.zeros(1, 0, n_mel_channels, device=device)
edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
fix_dur_list = fix_duration.copy() if fix_duration is not None else None
for part in parts_to_edit:
start, end = part
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
part_dur = part_dur * target_sample_rate
start = start * target_sample_rate
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0)
# Convert to frames (this is the authoritative unit)
start_frame = round(start * target_sample_rate / hop_length)
end_frame = round(end * target_sample_rate / hop_length)
part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length)
# Number of frames for the kept (non-edited) region
keep_frames = start_frame - offset_frame
# Build mel_cond: original mel frames + zero frames for edit region
mel_cond = torch.cat(
(
mel_cond,
original_mel[:, offset_frame:start_frame, :],
torch.zeros(1, part_dur_frames, n_mel_channels, device=device),
),
dim=1,
)
edit_mask = torch.cat(
(
edit_mask,
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
torch.ones(1, keep_frames, dtype=torch.bool, device=device),
torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device),
),
dim=-1,
)
offset = end * target_sample_rate
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
audio = audio.to(device)
edit_mask = edit_mask.to(device)
offset_frame = end_frame
# Append remaining mel frames after last edit
mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1)
edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)
# Text
text_list = [target_text]
@@ -170,14 +203,13 @@ else:
print(f"text : {text_list}")
print(f"pinyin: {final_text_list}")
# Duration
ref_audio_len = 0
duration = audio.shape[-1] // hop_length
# Duration - use mel_cond length (not raw audio length)
duration = mel_cond.shape[1]
# Inference
# Inference - pass mel_cond directly (not wav)
with torch.inference_mode():
generated, trajectory = model.sample(
cond=audio,
cond=mel_cond, # Now passing mel directly, not wav
text=final_text_list,
duration=duration,
steps=nfe_step,
@@ -190,7 +222,6 @@ with torch.inference_mode():
# Final result
generated = generated.to(torch.float32)
generated = generated[:, ref_audio_len:, :]
gen_mel_spec = generated.permute(0, 2, 1)
if mel_spec_type == "vocos":
generated_wave = vocoder.decode(gen_mel_spec).cpu()

View File

@@ -43,7 +43,7 @@ class TextEmbedding(nn.Module):
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
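The durations in the updated comment follow directly from the mel hop size; a quick check, assuming hop_length = 256 at 24 kHz (which reproduces the quoted figures):

```python
# Reproduce the durations quoted in the comment above (assumes hop_length = 256, 24 kHz).
hop_length, sample_rate = 256, 24000
for max_pos in (4096, 8192):
    print(max_pos, f"{max_pos * hop_length / sample_rate:.2f}s")  # 43.69s, 87.38s
```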
@@ -51,33 +51,29 @@ class TextEmbedding(nn.Module):
else:
self.extra_modeling = False
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
def average_upsample_text_by_mask(self, text, text_mask):
batch, text_len, text_dim = text.shape
if audio_mask is None:
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
valid_mask = audio_mask & text_mask
audio_lens = audio_mask.sum(dim=1) # [batch]
valid_lens = valid_mask.sum(dim=1) # [batch]
audio_len = text_len # because text is already padded to the same length as the audio sequence
text_lens = text_mask.sum(dim=1) # [batch]
upsampled_text = torch.zeros_like(text)
for i in range(batch):
audio_len = audio_lens[i].item()
valid_len = valid_lens[i].item()
text_len = text_lens[i].item()
if valid_len == 0:
if text_len == 0:
continue
valid_ind = torch.where(valid_mask[i])[0]
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
valid_ind = torch.where(text_mask[i])[0]
valid_data = text[i, valid_ind, :] # [text_len, text_dim]
base_repeat = audio_len // valid_len
remainder = audio_len % valid_len
base_repeat = audio_len // text_len
remainder = audio_len % text_len
indices = []
for j in range(valid_len):
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
for j in range(text_len):
repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
indices.extend([j] * repeat_count)
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
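The index construction spreads text positions as evenly as possible over the audio length, with the last `remainder` positions getting one extra repeat. A standalone illustration with hypothetical sizes:

```python
# Standalone illustration of the repeat-index logic above (hypothetical sizes).
audio_len, text_len = 8, 3
base_repeat, remainder = divmod(audio_len, text_len)  # 2, 2
indices = []
for j in range(text_len):
    repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
    indices.extend([j] * repeat_count)
print(indices[:audio_len])  # [0, 0, 1, 1, 1, 2, 2, 2]
```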
@@ -87,7 +83,7 @@ class TextEmbedding(nn.Module):
return upsampled_text
def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
def forward(self, text: int["b nt"], seq_len, drop_text=False):
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:
@@ -114,7 +110,7 @@ class TextEmbedding(nn.Module):
text = self.text_blocks(text)
if self.average_upsampling:
text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
text = self.average_upsample_text_by_mask(text, ~text_mask)
return text
@@ -247,17 +243,16 @@ class DiT(nn.Module):
):
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
else:
batch = x.shape[0]
seq_lens = audio_mask.sum(dim=1)
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
text_embed_list = []
for i in range(batch):
text_embed_i = self.text_embed(
text[i].unsqueeze(0),
seq_lens[i].item(),
seq_len=seq_lens[i].item(),
drop_text=drop_text,
audio_mask=audio_mask,
)
text_embed_list.append(text_embed_i[0])
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)

View File

@@ -7,7 +7,7 @@ import random
from collections import defaultdict
from importlib.resources import files
import jieba
import rjieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
@@ -146,10 +146,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
def convert_char_to_pinyin(text_list, polyphone=True):
if jieba.dt.initialized is False:
jieba.default_logger.setLevel(50) # CRITICAL
jieba.initialize()
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
@@ -163,7 +159,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
for text in text_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
for seg in rjieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
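rjieba is used here as a drop-in for jieba's cut() and needs no explicit initialization, which is why the jieba setup lines above are dropped. A rough usage sketch (the resulting segmentation is illustrative):

```python
# Rough usage sketch for the jieba -> rjieba swap (output segmentation is illustrative).
import rjieba

text = "对,这就是我,万人敬仰的太乙真人。"
for seg in rjieba.cut(text):
    print(seg, len(bytes(seg, "UTF-8")))  # segment and its byte length, as used above
```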

View File

@@ -1,3 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos
WORKDIR /workspace

View File

@@ -26,7 +26,7 @@
import json
import os
import jieba
import rjieba
import torch
import torchaudio
import triton_python_backend_utils as pb_utils
@@ -66,7 +66,7 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
for seg in rjieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":

View File

@@ -225,5 +225,5 @@ if __name__ == "__main__":
# bad zh asr cnt 230435 (samples)
# bad eh asr cnt 37217 (samples)
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

View File

@@ -122,5 +122,5 @@ if __name__ == "__main__":
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042
# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same