mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-01-05 09:39:52 -08:00
Fix speech editing boundary artifacts by working in mel domain
Previously, speech_edit.py worked in wav domain (inserting zeros into the waveform before computing mel spectrogram), which caused boundary artifacts when mel spectrogram windows straddled zeros and real audio. This commit refactors the approach to work in mel domain: - Compute mel spectrogram on the clean original audio first - Insert zero frames in mel domain instead of zero samples in wav domain - Use frame-level granularity throughout for consistency Benefits: - Eliminates boundary artifacts - More consistent behavior regardless of small float variations in input times - Cleaner edit boundaries Changes to speech_edit.py (lines 148-220): - Convert audio to mel using model.mel_spec() before editing - Build mel_cond by concatenating original mel frames + zero frames - Calculate all time-based values at frame level first, then convert to samples - Pass mel_cond directly to model.sample() instead of raw audio
This commit is contained in:
@@ -89,6 +89,12 @@ fix_duration = [
|
||||
# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
|
||||
# fix_duration = None # use origin text duration
|
||||
|
||||
# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
|
||||
# origin_text = "对,这就是我,万人敬仰的太乙真人。"
|
||||
# target_text = "对,这就是你,万人敬仰的李白金星。"
|
||||
# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
|
||||
# fix_duration = [1.284, 2.677]
|
||||
|
||||
|
||||
# -------------------------------------------------#
|
||||
|
||||
@@ -138,28 +144,55 @@ if rms < target_rms:
|
||||
if sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
|
||||
audio = resampler(audio)
|
||||
offset = 0
|
||||
audio_ = torch.zeros(1, 0)
|
||||
edit_mask = torch.zeros(1, 0, dtype=torch.bool)
|
||||
|
||||
# Convert to mel spectrogram FIRST (on clean original audio)
|
||||
# This avoids boundary artifacts from mel windows straddling zeros and real audio
|
||||
audio = audio.to(device)
|
||||
with torch.inference_mode():
|
||||
original_mel = model.mel_spec(audio) # (batch, n_mel, n_frames)
|
||||
original_mel = original_mel.permute(0, 2, 1) # (batch, n_frames, n_mel)
|
||||
|
||||
# Build mel_cond and edit_mask at FRAME level
|
||||
# Insert zero frames in mel domain instead of zero samples in wav domain
|
||||
offset_frame = 0
|
||||
mel_cond = torch.zeros(1, 0, n_mel_channels, device=device)
|
||||
edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
|
||||
fix_dur_list = fix_duration.copy() if fix_duration is not None else None
|
||||
|
||||
for part in parts_to_edit:
|
||||
start, end = part
|
||||
part_dur = end - start if fix_duration is None else fix_duration.pop(0)
|
||||
part_dur = part_dur * target_sample_rate
|
||||
start = start * target_sample_rate
|
||||
audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
|
||||
part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0)
|
||||
|
||||
# Convert to frames (this is the authoritative unit)
|
||||
start_frame = round(start * target_sample_rate / hop_length)
|
||||
end_frame = round(end * target_sample_rate / hop_length)
|
||||
part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length)
|
||||
|
||||
# Number of frames for the kept (non-edited) region
|
||||
keep_frames = start_frame - offset_frame
|
||||
|
||||
# Build mel_cond: original mel frames + zero frames for edit region
|
||||
mel_cond = torch.cat(
|
||||
(
|
||||
mel_cond,
|
||||
original_mel[:, offset_frame:start_frame, :],
|
||||
torch.zeros(1, part_dur_frames, n_mel_channels, device=device),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
edit_mask = torch.cat(
|
||||
(
|
||||
edit_mask,
|
||||
torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
|
||||
torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
|
||||
torch.ones(1, keep_frames, dtype=torch.bool, device=device),
|
||||
torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
offset = end * target_sample_rate
|
||||
audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
|
||||
edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
|
||||
audio = audio.to(device)
|
||||
edit_mask = edit_mask.to(device)
|
||||
offset_frame = end_frame
|
||||
|
||||
# Append remaining mel frames after last edit
|
||||
mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1)
|
||||
edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)
|
||||
|
||||
# Text
|
||||
text_list = [target_text]
|
||||
@@ -170,14 +203,13 @@ else:
|
||||
print(f"text : {text_list}")
|
||||
print(f"pinyin: {final_text_list}")
|
||||
|
||||
# Duration
|
||||
ref_audio_len = 0
|
||||
duration = audio.shape[-1] // hop_length
|
||||
# Duration - use mel_cond length (not raw audio length)
|
||||
duration = mel_cond.shape[1]
|
||||
|
||||
# Inference
|
||||
# Inference - pass mel_cond directly (not wav)
|
||||
with torch.inference_mode():
|
||||
generated, trajectory = model.sample(
|
||||
cond=audio,
|
||||
cond=mel_cond, # Now passing mel directly, not wav
|
||||
text=final_text_list,
|
||||
duration=duration,
|
||||
steps=nfe_step,
|
||||
@@ -190,7 +222,6 @@ with torch.inference_mode():
|
||||
|
||||
# Final result
|
||||
generated = generated.to(torch.float32)
|
||||
generated = generated[:, ref_audio_len:, :]
|
||||
gen_mel_spec = generated.permute(0, 2, 1)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||
|
||||
Reference in New Issue
Block a user