@@ -89,6 +89,12 @@ fix_duration = [
+# parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
+# fix_duration = None  # use origin text duration
+
 # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
 # origin_text = "对,这就是我,万人敬仰的太乙真人。"
 # target_text = "对,这就是你,万人敬仰的李白金星。"
+# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
+# fix_duration = [1.284, 2.677]
 
 # -------------------------------------------------#
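Each parts_to_edit entry above is a [start_sec, end_sec] span of the original recording to regenerate, and fix_duration, when not None, gives one replacement duration in seconds per span. A minimal sanity check along those lines (this helper is illustrative only, not part of the script):

    def check_edit_config(parts_to_edit, fix_duration):
        # one pinned duration per edited span, if durations are pinned at all
        if fix_duration is not None:
            assert len(fix_duration) == len(parts_to_edit)
        for start, end in parts_to_edit:
            assert 0 <= start < end, "each span must be [start_sec, end_sec] with start < end"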
@@ -138,28 +144,55 @@ if rms < target_rms:
 if sr != target_sample_rate:
     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
     audio = resampler(audio)
-offset = 0
-audio_ = torch.zeros(1, 0)
-edit_mask = torch.zeros(1, 0, dtype=torch.bool)
+
+# Convert to mel spectrogram FIRST (on clean original audio)
+# This avoids boundary artifacts from mel windows straddling zeros and real audio
+audio = audio.to(device)
+with torch.inference_mode():
+    original_mel = model.mel_spec(audio)  # (batch, n_mel, n_frames)
+    original_mel = original_mel.permute(0, 2, 1)  # (batch, n_frames, n_mel)
+
+# Build mel_cond and edit_mask at FRAME level
+# Insert zero frames in mel domain instead of zero samples in wav domain
+offset_frame = 0
+mel_cond = torch.zeros(1, 0, n_mel_channels, device=device)
+edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
+fix_dur_list = fix_duration.copy() if fix_duration is not None else None
+
 for part in parts_to_edit:
     start, end = part
-    part_dur = end - start if fix_duration is None else fix_duration.pop(0)
-    part_dur = part_dur * target_sample_rate
-    start = start * target_sample_rate
-    audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
+    part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0)
+
+    # Convert to frames (this is the authoritative unit)
+    start_frame = round(start * target_sample_rate / hop_length)
+    end_frame = round(end * target_sample_rate / hop_length)
+    part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length)
+
+    # Number of frames for the kept (non-edited) region
+    keep_frames = start_frame - offset_frame
+
+    # Build mel_cond: original mel frames + zero frames for edit region
+    mel_cond = torch.cat(
+        (
+            mel_cond,
+            original_mel[:, offset_frame:start_frame, :],
+            torch.zeros(1, part_dur_frames, n_mel_channels, device=device),
+        ),
+        dim=1,
+    )
     edit_mask = torch.cat(
         (
             edit_mask,
-            torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
-            torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
+            torch.ones(1, keep_frames, dtype=torch.bool, device=device),
+            torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device),
         ),
         dim=-1,
     )
-    offset = end * target_sample_rate
-audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
-edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
-audio = audio.to(device)
-edit_mask = edit_mask.to(device)
+    offset_frame = end_frame
+
+# Append remaining mel frames after last edit
+mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1)
+edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)
 
 # Text
 text_list = [target_text]
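For reference, the seconds-to-frames mapping the new frame-level code relies on, shown standalone (the 24 kHz sample rate and hop length of 256 are assumed here for illustration, matching the script's usual defaults):

    target_sample_rate = 24000  # assumed default
    hop_length = 256            # assumed default

    def sec_to_frame(t_sec):
        # mel frame index corresponding to a time in seconds
        return round(t_sec * target_sample_rate / hop_length)

    # e.g. the commented zh example span [1.500, 2.784] covers frames 141..261;
    # those frames are zeroed in mel_cond and marked False in edit_mask,
    # while kept regions stay True (left untouched by generation)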
@@ -170,14 +203,13 @@ else:
 print(f"text : {text_list}")
 print(f"pinyin: {final_text_list}")
 
-# Duration
-ref_audio_len = 0
-duration = audio.shape[-1] // hop_length
+# Duration - use mel_cond length (not raw audio length)
+duration = mel_cond.shape[1]
 
-# Inference
+# Inference - pass mel_cond directly (not wav)
 with torch.inference_mode():
     generated, trajectory = model.sample(
-        cond=audio,
+        cond=mel_cond,  # Now passing mel directly, not wav
         text=final_text_list,
         duration=duration,
         steps=nfe_step,
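Because cond is now the frame-level mel and duration equals mel_cond.shape[1], the sampled output is expected to already have the edited utterance's full length, which is why the ref_audio_len trimming disappears in the next hunk. An illustrative check under that assumption:

    # assumes model.sample returns (batch, n_frames, n_mel), as implied by the permute below
    assert generated.shape[1] == mel_cond.shape[1] == duration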
@@ -190,7 +222,6 @@ with torch.inference_mode():
 
 # Final result
 generated = generated.to(torch.float32)
-generated = generated[:, ref_audio_len:, :]
 gen_mel_spec = generated.permute(0, 2, 1)
 if mel_spec_type == "vocos":
     generated_wave = vocoder.decode(gen_mel_spec).cpu()
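If the rest of the script is unchanged, the decoded waveform can then be written out as usual; a minimal sketch (the output path is illustrative):

    import torchaudio

    # assumes generated_wave is a (channels, num_samples) CPU tensor after vocoder.decode(...).cpu()
    torchaudio.save("speech_edit_out.wav", generated_wave, target_sample_rate)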