Mirror of https://github.com/SWivid/F5-TTS.git (synced 2026-01-15 06:23:22 -08:00)

Compare commits: 1dcb4e10f7...1.1.12 (7 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 39617fcf7a | |
| | 5b82f97c26 | |
| | 9ae46c8360 | |
| | 3eecd94baa | |
| | d9a69452ce | |
| | bc15df2b57 | |
| | 9b2357a1b9 | |

@@ -30,6 +30,9 @@
# Create a conda env with python_version>=3.10 (you could also use virtualenv)
conda create -n f5-tts python=3.11
conda activate f5-tts
+
+# Install FFmpeg if you haven't yet
+conda install ffmpeg
```

### Install PyTorch with matched device

@@ -39,7 +42,11 @@ conda activate f5-tts

> ```bash
> # Install pytorch with your CUDA version, e.g.
+> pip install torch==2.8.0+cu128 torchaudio==2.8.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
+>
+> # And also possible previous versions, e.g.
> pip install torch==2.4.0+cu124 torchaudio==2.4.0+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
+> # etc.
> ```

</details>

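Not part of the diff, just a quick sanity check after installation: the wheel's build tag should match the CUDA version chosen above, and the GPU should be visible to that build.

```python
import torch

# Print the build tags of the installed wheel, e.g. "2.8.0+cu128" / "12.8",
# and whether a CUDA device is actually visible to this build.
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())
```
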
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "f5-tts"
-version = "1.1.9"
+version = "1.1.12"
description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
readme = "README.md"
license = {text = "MIT License"}

@@ -20,15 +20,15 @@ dependencies = [
"click",
"datasets",
"ema_pytorch>=0.5.2",
-"gradio>=5.0.0",
+"gradio>=6.0.0",
"hydra-core>=1.3.0",
-"jieba",
"librosa",
"matplotlib",
"numpy<=1.26.4; python_version<='3.10'",
"pydantic<=2.10.6",
"pydub",
"pypinyin",
+"rjieba",
"safetensors",
"soundfile",
"tomli",

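Since the floor moves from `gradio>=5.0.0` to `gradio>=6.0.0`, an environment created before this bump may still carry an older gradio and hit the API differences shown below. A minimal runtime check, as a sketch (the third-party `packaging` helper used here is assumed to be available; it is not in the dependency list above):

```python
from importlib.metadata import version

from packaging.version import Version

# Warn if the installed gradio predates the new floor in pyproject.toml.
if Version(version("gradio")) < Version("6.0.0"):
    print("gradio < 6.0.0 detected; upgrade with: pip install 'gradio>=6.0.0'")
```
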
@@ -577,7 +577,7 @@ with gr.Blocks() as app_multistyle:
label="Cherry-pick Interface",
lines=10,
max_lines=40,
-show_copy_button=True,
+buttons=["copy"], # show_copy_button=True if gradio<6.0
interactive=False,
visible=False,
)

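Per the inline comment above, `buttons=["copy"]` replaces `show_copy_button=True` from gradio 6.0 on. A hedged sketch of how code that still supports both major versions could select the keyword at runtime (this gating helper is illustrative, not part of the repo):

```python
import gradio as gr

# Choose the copy-button keyword based on the installed gradio major version,
# per the comment in the diff: buttons=["copy"] on >=6.0, show_copy_button=True below.
gradio_major = int(gr.__version__.split(".")[0])
copy_kwargs = {"buttons": ["copy"]} if gradio_major >= 6 else {"show_copy_button": True}

textbox = gr.Textbox(
    label="Cherry-pick Interface",
    lines=10,
    max_lines=40,
    interactive=False,
    visible=False,
    **copy_kwargs,
)
```
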
@@ -816,7 +816,9 @@ Have a conversation with an AI using your reference voice!
lines=2,
)

-chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
+chatbot_interface = gr.Chatbot(
+label="Conversation"
+) # type="messages" hard-coded and no need to pass in since gradio 6.0

with gr.Row():
with gr.Column():

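From gradio 6.0 the Chatbot always uses the "messages" history format, so `type="messages"` no longer needs to be passed. The handlers below operate on that format: a list of role/content dicts, where assistant content may arrive wrapped as a one-element list of text parts. An illustrative sample of such a history (values invented for the example):

```python
# Illustrative conversation state in the "messages" format used by the Chatbot
# and by the handlers below; the wrapped assistant content mirrors what the
# normalization loop in generate_text_response unwraps.
conv_state = [
    {"role": "user", "content": "Hi! Please introduce yourself briefly."},
    {"role": "assistant", "content": [{"type": "text", "text": "Hello, I am your AI assistant."}]},
]
```
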
@@ -853,6 +855,10 @@ Have a conversation with an AI using your reference voice!
@gpu_decorator
def generate_text_response(conv_state, system_prompt):
"""Generate text response from AI"""
+for single_state in conv_state:
+if isinstance(single_state["content"], list):
+assert len(single_state["content"]) == 1 and single_state["content"][0]["type"] == "text"
+single_state["content"] = single_state["content"][0]["text"]

system_prompt_state = [{"role": "system", "content": system_prompt}]
response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)

@@ -866,7 +872,7 @@ Have a conversation with an AI using your reference voice!
if not conv_state or not ref_audio:
return None, ref_text, seed_input

-last_ai_response = conv_state[-1]["content"]
+last_ai_response = conv_state[-1]["content"][0]["text"]
if not last_ai_response or conv_state[-1]["role"] != "assistant":
return None, ref_text, seed_input

@@ -1108,7 +1114,6 @@ def main(port, host, share, api, root_path, inbrowser):
server_name=host,
server_port=port,
share=share,
-show_api=api,
root_path=root_path,
inbrowser=inbrowser,
)

@@ -43,7 +43,7 @@ class TextEmbedding(nn.Module):

if conv_layers > 0:
self.extra_modeling = True
-self.precompute_max_pos = 4096 # ~44s of 24khz audio
+self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]

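The durations in the new comment follow from frames × hop length ÷ sample rate; the sketch below assumes a mel hop length of 256 samples at 24 kHz, which is not stated in this diff:

```python
# Rough check of the comment above: seconds = frames * hop_length / sample_rate.
hop_length, sample_rate = 256, 24000  # hop length assumed, not given in the diff
for max_pos in (4096, 8192):
    print(max_pos, round(max_pos * hop_length / sample_rate, 2))  # ~43.69 s and ~87.38 s
```
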
@@ -51,33 +51,29 @@ class TextEmbedding(nn.Module):
else:
self.extra_modeling = False

-def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
+def average_upsample_text_by_mask(self, text, text_mask):
batch, text_len, text_dim = text.shape

-if audio_mask is None:
-audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
-valid_mask = audio_mask & text_mask
-audio_lens = audio_mask.sum(dim=1) # [batch]
-valid_lens = valid_mask.sum(dim=1) # [batch]
+audio_len = text_len # cuz text already padded to same length as audio sequence
+text_lens = text_mask.sum(dim=1) # [batch]

upsampled_text = torch.zeros_like(text)

for i in range(batch):
-audio_len = audio_lens[i].item()
-valid_len = valid_lens[i].item()
+text_len = text_lens[i].item()

-if valid_len == 0:
+if text_len == 0:
continue

-valid_ind = torch.where(valid_mask[i])[0]
-valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
+valid_ind = torch.where(text_mask[i])[0]
+valid_data = text[i, valid_ind, :] # [text_len, text_dim]

-base_repeat = audio_len // valid_len
-remainder = audio_len % valid_len
+base_repeat = audio_len // text_len
+remainder = audio_len % text_len

indices = []
-for j in range(valid_len):
-repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
+for j in range(text_len):
+repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
indices.extend([j] * repeat_count)

indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)

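The simplified `average_upsample_text_by_mask` spreads the `text_len` valid text positions over the padded audio length by repeating indices, giving the last `remainder` positions one extra repeat. A standalone toy version of just that index pattern (hypothetical helper, not repo code):

```python
import torch

def repeat_indices(text_len: int, audio_len: int) -> torch.Tensor:
    """Indices that stretch text_len positions over audio_len frames, as above."""
    base_repeat, remainder = audio_len // text_len, audio_len % text_len
    indices = []
    for j in range(text_len):
        indices.extend([j] * (base_repeat + (1 if j >= text_len - remainder else 0)))
    return torch.tensor(indices[:audio_len], dtype=torch.long)

print(repeat_indices(3, 8))  # tensor([0, 0, 1, 1, 1, 2, 2, 2])
```
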
@@ -87,7 +83,7 @@ class TextEmbedding(nn.Module):

return upsampled_text

-def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
+def forward(self, text: int["b nt"], seq_len, drop_text=False):
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling:

@@ -114,7 +110,7 @@ class TextEmbedding(nn.Module):
text = self.text_blocks(text)

if self.average_upsampling:
-text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
+text = self.average_upsample_text_by_mask(text, ~text_mask)

return text

@@ -247,17 +243,16 @@ class DiT(nn.Module):
):
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
-text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
else:
batch = x.shape[0]
-seq_lens = audio_mask.sum(dim=1)
+seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
text_embed_list = []
for i in range(batch):
text_embed_i = self.text_embed(
text[i].unsqueeze(0),
-seq_lens[i].item(),
+seq_len=seq_lens[i].item(),
drop_text=drop_text,
-audio_mask=audio_mask,
)
text_embed_list.append(text_embed_i[0])
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)

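With an `audio_mask` present, each sample is now embedded at its own true length and the per-sample results are padded back into one batch tensor. A minimal illustration of that padding step with toy shapes:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two per-sample embeddings of different lengths, stacked into [batch, max_len, dim].
text_embed_list = [torch.randn(5, 4), torch.randn(3, 4)]
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
print(text_embed.shape)  # torch.Size([2, 5, 4])
```
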
@@ -7,7 +7,7 @@ import random
from collections import defaultdict
from importlib.resources import files

-import jieba
+import rjieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence

@@ -146,10 +146,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):


def convert_char_to_pinyin(text_list, polyphone=True):
-if jieba.dt.initialized is False:
-jieba.default_logger.setLevel(50) # CRITICAL
-jieba.initialize()
-
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}

@@ -163,7 +159,7 @@ convert_char_to_pinyin(text_list, polyphone=True):
for text in text_list:
char_list = []
text = text.translate(custom_trans)
-for seg in jieba.cut(text):
+for seg in rjieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":

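The `seg_byte_len == len(seg)` test works because ASCII characters are one byte in UTF-8 while CJK characters are typically three, so only a pure alphabet-and-symbol segment has byte length equal to character count. A tiny standalone check (plain Python, no segmenter required):

```python
# For ASCII-only segments, UTF-8 byte length equals character count; CJK differs.
for seg in ("hello,", "语音合成"):
    print(repr(seg), len(bytes(seg, "UTF-8")) == len(seg))
# 'hello,' True
# '语音合成' False
```
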
@@ -1,3 +1,3 @@
FROM nvcr.io/nvidia/tritonserver:24.12-py3
-RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
+RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos
WORKDIR /workspace

@@ -26,7 +26,7 @@
import json
import os

-import jieba
+import rjieba
import torch
import torchaudio
import triton_python_backend_utils as pb_utils

@@ -66,7 +66,7 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
for text in reference_target_texts_list:
char_list = []
text = text.translate(custom_trans)
-for seg in jieba.cut(text):
+for seg in rjieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":

@@ -225,5 +225,5 @@ if __name__ == "__main__":
# bad zh asr cnt 230435 (samples)
# bad eh asr cnt 37217 (samples)

-# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
+# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

@@ -122,5 +122,5 @@ if __name__ == "__main__":
# - - 1459 (polyphone)
# char vocab size 5264 5219 5042

-# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
+# vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
# please be careful if using pretrained model, make sure the vocab.txt is same

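The caution about vocab drift exists because both the word segmenter and pypinyin's polyphone handling decide which pinyin tokens end up in vocab.txt. A small pypinyin probe (`lazy_pinyin` is the real API imported above; exact outputs depend on pypinyin's phrase data and version):

```python
from pypinyin import Style, lazy_pinyin

# The polyphonic character 行 reads "hang2" in 银行 (bank) but "xing2" in 行走 (to walk);
# which reading pypinyin emits, and thus which tokens reach the vocab, depends on its data.
print(lazy_pinyin("银行", style=Style.TONE3))
print(lazy_pinyin("行走", style=Style.TONE3))
```
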