Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-26 12:51:16 -08:00)
Compare commits
5 Commits
| SHA1 |
|---|
| 27e20fcf39 |
| dff57ebd2a |
| 46ccc575c5 |
| 39617fcf7a |
| 5b82f97c26 |
@@ -30,6 +30,9 @@
 # Create a conda env with python_version>=3.10 (you could also use virtualenv)
 conda create -n f5-tts python=3.11
 conda activate f5-tts
+
+# Install FFmpeg if you haven't yet
+conda install ffmpeg
 ```
 
 ### Install PyTorch with matched device
@@ -39,7 +42,11 @@ conda activate f5-tts
 
 > ```bash
 > # Install pytorch with your CUDA version, e.g.
+> pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
+>
+> # And also possible previous versions, e.g.
 > pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
+> # etc.
 > ```
 
 </details>
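Not part of the diff: once a CUDA build of PyTorch has been installed as above, a quick sanity check confirms that the wheel actually matches the device.

```python
# Quick sanity check (not part of the diff): confirm the installed PyTorch
# build sees the GPU and matches the CUDA version chosen above.
import torch
import torchaudio

print("torch:", torch.__version__)            # e.g. 2.8.0+cu128
print("torchaudio:", torchaudio.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA runtime:", torch.version.cuda)
    print("device:", torch.cuda.get_device_name(0))
```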
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.1.10"
+version = "1.1.15"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
@@ -20,7 +20,7 @@ dependencies = [
     "click",
     "datasets",
     "ema_pytorch>=0.5.2",
-    "gradio>=5.0.0",
+    "gradio>=6.0.0",
     "hydra-core>=1.3.0",
     "librosa",
     "matplotlib",
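As an aside (not in the diff), an existing environment can be checked against the bumped pins at runtime:

```python
# A small sketch (not part of the PR): confirm the installed packages
# satisfy the bumped pins from pyproject.toml.
from importlib.metadata import version

print("f5-tts:", version("f5-tts"))   # expect 1.1.15 after this release
print("gradio:", version("gradio"))   # the pin is now gradio>=6.0.0
```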
@@ -221,7 +221,7 @@ with gr.Blocks() as app_tts:
         )
         gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
     generate_btn = gr.Button("Synthesize", variant="primary")
-    with gr.Accordion("Advanced Settings", open=False):
+    with gr.Accordion("Advanced Settings", open=True) as adv_settn:
         with gr.Row():
             ref_text_input = gr.Textbox(
                 label="Reference Text",
@@ -269,6 +269,17 @@ with gr.Blocks() as app_tts:
             info="Set the duration of the cross-fade between audio clips.",
         )
 
+    def collapse_accordion():
+        return gr.Accordion(open=False)
+
+    # Workaround for https://github.com/SWivid/F5-TTS/issues/1239#issuecomment-3677987413
+    # i.e. set gr.Accordion(open=True) by default, then collapse it once the Blocks app has loaded
+    app_tts.load(
+        fn=collapse_accordion,
+        inputs=None,
+        outputs=adv_settn,
+    )
+
     audio_output = gr.Audio(label="Synthesized Audio")
     spectrogram_output = gr.Image(label="Spectrogram")
 
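The hunk above works around the layout issue linked in the comment: the accordion is created open so its children render with correct initial sizes, then collapsed as soon as the page has loaded. A minimal, self-contained sketch of the same pattern (component names here are illustrative, not from the repo):

```python
# Minimal sketch of the collapse-on-load workaround: render the accordion
# open, then collapse it via a Blocks.load() callback once the page is up.
import gradio as gr

with gr.Blocks() as demo:
    with gr.Accordion("Advanced Settings", open=True) as adv:
        gr.Slider(0.0, 1.0, value=0.15, label="Cross-Fade Duration (s)")

    def collapse_accordion():
        # Returning a component instance acts as an update for the output.
        return gr.Accordion(open=False)

    demo.load(fn=collapse_accordion, inputs=None, outputs=adv)

if __name__ == "__main__":
    demo.launch()
```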
@@ -577,7 +588,7 @@ with gr.Blocks() as app_multistyle:
         label="Cherry-pick Interface",
         lines=10,
         max_lines=40,
-        show_copy_button=True,
+        buttons=["copy"],  # show_copy_button=True if gradio<6.0
         interactive=False,
         visible=False,
     )
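The inline comment documents the Gradio 6 rename of the copy-button option. A hedged compatibility sketch (not in the PR) that picks whichever kwarg the installed Gradio accepts:

```python
# Hedged compatibility sketch: buttons=["copy"] per the hunk above on
# Gradio >= 6, show_copy_button=True on earlier releases.
import gradio as gr

_GRADIO_6 = int(gr.__version__.split(".")[0]) >= 6
copy_kwargs = {"buttons": ["copy"]} if _GRADIO_6 else {"show_copy_button": True}

cherrypick_box = gr.Textbox(label="Cherry-pick Interface", interactive=False, **copy_kwargs)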
@@ -816,7 +827,9 @@ Have a conversation with an AI using your reference voice!
                 lines=2,
             )
 
-    chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
+    chatbot_interface = gr.Chatbot(
+        label="Conversation"
+    )  # type="messages" hard-coded and no need to pass in since gradio 6.0
 
     with gr.Row():
         with gr.Column():
@@ -853,6 +866,10 @@ Have a conversation with an AI using your reference voice!
     @gpu_decorator
     def generate_text_response(conv_state, system_prompt):
         """Generate text response from AI"""
+        for single_state in conv_state:
+            if isinstance(single_state["content"], list):
+                assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text"
+                single_state["content"] = single_state["content"][0]["text"]
 
         system_prompt_state = [{"role": "system", "content": system_prompt}]
         response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
@@ -866,7 +883,7 @@ Have a conversation with an AI using your reference voice!
         if not conv_state or not ref_audio:
             return None, ref_text, seed_input
 
-        last_ai_response = conv_state[-1]["content"]
+        last_ai_response = conv_state[-1]["content"][0]["text"]
         if not last_ai_response or conv_state[-1]["role"] != "assistant":
             return None, ref_text, seed_input
 
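The hunks above adapt the chat tab to the Gradio 6 "messages" history format, where an entry's content may arrive as a list of typed parts rather than a plain string. Purely illustrative sample data (not from the repo) showing what the added flattening loop does:

```python
# Illustrative only: flatten single text parts back to plain strings
# before handing the history to the chat model, as in the hunk above.
conv_state = [
    {"role": "user", "content": [{"type": "text", "text": "Hello there"}]},
    {"role": "assistant", "content": "Hi! How can I help?"},
]

for single_state in conv_state:
    if isinstance(single_state["content"], list):
        assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text"
        single_state["content"] = single_state["content"][0]["text"]

print(conv_state[0]["content"])  # -> "Hello there"
```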
@@ -1108,7 +1125,6 @@ def main(port, host, share, api, root_path, inbrowser):
         server_name=host,
         server_port=port,
         share=share,
-        show_api=api,
         root_path=root_path,
         inbrowser=inbrowser,
     )
@@ -89,6 +89,12 @@ fix_duration = [
 # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
 # fix_duration = None  # use origin text duration
 
+# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
+# origin_text = "对,这就是我,万人敬仰的太乙真人。"
+# target_text = "对,这就是你,万人敬仰的李白金星。"
+# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]]
+# fix_duration = [1.284, 2.677]
+
 
 # -------------------------------------------------#
 
@@ -138,28 +144,55 @@ if rms < target_rms:
 if sr != target_sample_rate:
     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
     audio = resampler(audio)
-offset = 0
-audio_ = torch.zeros(1, 0)
-edit_mask = torch.zeros(1, 0, dtype=torch.bool)
+
+# Convert to mel spectrogram FIRST (on clean original audio)
+# This avoids boundary artifacts from mel windows straddling zeros and real audio
+audio = audio.to(device)
+with torch.inference_mode():
+    original_mel = model.mel_spec(audio)  # (batch, n_mel, n_frames)
+    original_mel = original_mel.permute(0, 2, 1)  # (batch, n_frames, n_mel)
+
+# Build mel_cond and edit_mask at FRAME level
+# Insert zero frames in mel domain instead of zero samples in wav domain
+offset_frame = 0
+mel_cond = torch.zeros(1, 0, n_mel_channels, device=device)
+edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device)
+fix_dur_list = fix_duration.copy() if fix_duration is not None else None
+
 for part in parts_to_edit:
     start, end = part
-    part_dur = end - start if fix_duration is None else fix_duration.pop(0)
-    part_dur = part_dur * target_sample_rate
-    start = start * target_sample_rate
-    audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
+    part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0)
+
+    # Convert to frames (this is the authoritative unit)
+    start_frame = round(start * target_sample_rate / hop_length)
+    end_frame = round(end * target_sample_rate / hop_length)
+    part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length)
+
+    # Number of frames for the kept (non-edited) region
+    keep_frames = start_frame - offset_frame
+
+    # Build mel_cond: original mel frames + zero frames for edit region
+    mel_cond = torch.cat(
+        (
+            mel_cond,
+            original_mel[:, offset_frame:start_frame, :],
+            torch.zeros(1, part_dur_frames, n_mel_channels, device=device),
+        ),
+        dim=1,
+    )
     edit_mask = torch.cat(
         (
             edit_mask,
-            torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
-            torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
+            torch.ones(1, keep_frames, dtype=torch.bool, device=device),
+            torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device),
         ),
         dim=-1,
     )
-    offset = end * target_sample_rate
-audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1)
-edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
-audio = audio.to(device)
-edit_mask = edit_mask.to(device)
+    offset_frame = end_frame
+
+# Append remaining mel frames after last edit
+mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1)
+edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True)
 
 # Text
 text_list = [target_text]
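This rework moves the edit bookkeeping from raw samples to mel frames: the clean original audio is converted to a mel spectrogram once, then per edit span the kept frames are copied and the span to regenerate is filled with zero frames, with a matching boolean edit_mask. A compact, self-contained sketch of the frame arithmetic, using sizes assumed from the usual F5-TTS config (24 kHz audio, hop_length 256, 100 mel channels) and the edit span from the commented example earlier in the diff:

```python
# Sketch of the seconds -> mel-frame arithmetic used above.
# Assumed config values: 24 kHz sample rate, hop_length 256, 100 mel channels.
import torch

target_sample_rate, hop_length, n_mel_channels = 24000, 256, 100

def sec_to_frames(t_sec: float) -> int:
    # One mel frame covers hop_length samples, so seconds -> frames via sr / hop.
    return round(t_sec * target_sample_rate / hop_length)

# Fake "original" mel with 10 s worth of frames, and one edit span of 1.284 s
# starting at 1.5 s (numbers taken from the commented example in the diff).
original_mel = torch.randn(1, sec_to_frames(10.0), n_mel_channels)
start_frame, end_frame = sec_to_frames(1.500), sec_to_frames(2.784)
part_dur_frames = sec_to_frames(1.284)

mel_cond = torch.cat(
    (
        original_mel[:, :start_frame, :],                 # kept frames before the edit
        torch.zeros(1, part_dur_frames, n_mel_channels),  # zero frames to regenerate
        original_mel[:, end_frame:, :],                   # kept frames after the edit
    ),
    dim=1,
)
edit_mask = torch.cat(
    (
        torch.ones(1, start_frame, dtype=torch.bool),
        torch.zeros(1, part_dur_frames, dtype=torch.bool),
        torch.ones(1, original_mel.shape[1] - end_frame, dtype=torch.bool),
    ),
    dim=-1,
)
print(mel_cond.shape, edit_mask.shape)  # frame counts match
```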
@@ -170,14 +203,13 @@ else:
 print(f"text : {text_list}")
 print(f"pinyin: {final_text_list}")
 
-# Duration
-ref_audio_len = 0
-duration = audio.shape[-1] // hop_length
+# Duration - use mel_cond length (not raw audio length)
+duration = mel_cond.shape[1]
 
-# Inference
+# Inference - pass mel_cond directly (not wav)
 with torch.inference_mode():
     generated, trajectory = model.sample(
-        cond=audio,
+        cond=mel_cond,  # Now passing mel directly, not wav
         text=final_text_list,
         duration=duration,
         steps=nfe_step,
@@ -190,7 +222,6 @@ with torch.inference_mode():
 
 # Final result
 generated = generated.to(torch.float32)
-generated = generated[:, ref_audio_len:, :]
 gen_mel_spec = generated.permute(0, 2, 1)
 if mel_spec_type == "vocos":
     generated_wave = vocoder.decode(gen_mel_spec).cpu()
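A note on shapes, illustrated below with assumed sizes: duration is now the mel_cond frame count, model.sample() works in (batch, n_frames, n_mel) order, and the final permute(0, 2, 1) restores the (batch, n_mel, n_frames) layout the vocoder decodes.

```python
# Shape bookkeeping sketch (illustrative sizes, not from the repo).
import torch

n_frames, n_mel = 938, 100                     # e.g. ~10 s at 24 kHz with hop_length 256
generated = torch.randn(1, n_frames, n_mel)    # shaped like the model.sample() output
gen_mel_spec = generated.permute(0, 2, 1)      # -> (1, n_mel, n_frames) for the vocoder
print(gen_mel_spec.shape)                      # torch.Size([1, 100, 938])
```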