Mirror of https://github.com/SWivid/F5-TTS.git, synced 2026-01-16 06:53:17 -08:00

Compare commits: 5 commits, 1dcb4e10f7...1.1.10
| Author | SHA1 | Date |
|---|---|---|
| | 9ae46c8360 | |
| | 3eecd94baa | |
| | d9a69452ce | |
| | bc15df2b57 | |
| | 9b2357a1b9 | |
pyproject.toml:

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "f5-tts"
-version = "1.1.9"
+version = "1.1.10"
 description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
 readme = "README.md"
 license = {text = "MIT License"}
```
```diff
@@ -22,13 +22,13 @@ dependencies = [
     "ema_pytorch>=0.5.2",
     "gradio>=5.0.0",
     "hydra-core>=1.3.0",
-    "jieba",
     "librosa",
     "matplotlib",
     "numpy<=1.26.4; python_version<='3.10'",
     "pydantic<=2.10.6",
     "pydub",
     "pypinyin",
+    "rjieba",
     "safetensors",
     "soundfile",
     "tomli",
```
Model backbone (classes TextEmbedding and DiT):

```diff
@@ -43,7 +43,7 @@ class TextEmbedding(nn.Module):
 
         if conv_layers > 0:
             self.extra_modeling = True
-            self.precompute_max_pos = 4096  # ~44s of 24khz audio
+            self.precompute_max_pos = 8192  # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
             self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
             self.text_blocks = nn.Sequential(
                 *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
```
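The doubled `precompute_max_pos` maps to audio duration via the mel frame rate. Assuming the 256-sample hop length F5-TTS uses for its 24 kHz mel spectrograms (an assumption here, not shown in the diff), each position covers 256/24000 s, which reproduces the durations quoted in the new comment:

```python
# Sanity check of the durations in the diff comment, assuming hop_length=256
# at a 24 kHz sample rate: seconds = max_pos * hop_length / sample_rate.
hop_length, sample_rate = 256, 24000
for max_pos in (4096, 8192):
    print(f"{max_pos} frames ≈ {max_pos * hop_length / sample_rate:.2f} s")
# 4096 frames ≈ 43.69 s
# 8192 frames ≈ 87.38 s
```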
```diff
@@ -51,33 +51,29 @@ class TextEmbedding(nn.Module):
         else:
             self.extra_modeling = False
 
-    def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
+    def average_upsample_text_by_mask(self, text, text_mask):
         batch, text_len, text_dim = text.shape
 
-        if audio_mask is None:
-            audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
-        valid_mask = audio_mask & text_mask
-        audio_lens = audio_mask.sum(dim=1)  # [batch]
-        valid_lens = valid_mask.sum(dim=1)  # [batch]
+        audio_len = text_len  # cuz text already padded to same length as audio sequence
+        text_lens = text_mask.sum(dim=1)  # [batch]
 
         upsampled_text = torch.zeros_like(text)
 
         for i in range(batch):
-            audio_len = audio_lens[i].item()
-            valid_len = valid_lens[i].item()
+            text_len = text_lens[i].item()
 
-            if valid_len == 0:
+            if text_len == 0:
                 continue
 
-            valid_ind = torch.where(valid_mask[i])[0]
-            valid_data = text[i, valid_ind, :]  # [valid_len, text_dim]
+            valid_ind = torch.where(text_mask[i])[0]
+            valid_data = text[i, valid_ind, :]  # [text_len, text_dim]
 
-            base_repeat = audio_len // valid_len
-            remainder = audio_len % valid_len
+            base_repeat = audio_len // text_len
+            remainder = audio_len % text_len
 
             indices = []
-            for j in range(valid_len):
-                repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
+            for j in range(text_len):
+                repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
                 indices.extend([j] * repeat_count)
 
             indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
```
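The refactor drops the `audio_mask` argument: the text tensor is already padded to the audio sequence length, so `audio_len` comes straight from the tensor shape and only the per-sample count of valid text tokens varies. The heart of the method is the repeat-count distribution, sketched standalone below (`upsample_indices` is a hypothetical helper name, same logic as the loop body above):

```python
# Minimal sketch of the upsampling index logic: audio_len positions are
# spread over text_len tokens, with the last `remainder` tokens getting
# one extra repeat so the counts sum exactly to audio_len.
def upsample_indices(audio_len: int, text_len: int) -> list[int]:
    base_repeat, remainder = divmod(audio_len, text_len)
    indices = []
    for j in range(text_len):
        repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
        indices.extend([j] * repeat_count)
    return indices[:audio_len]

print(upsample_indices(10, 3))  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
```

In the real method these indices then gather rows of `valid_data`, stretching each text token's embedding across its share of the audio frames.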
```diff
@@ -87,7 +83,7 @@ class TextEmbedding(nn.Module):
 
         return upsampled_text
 
-    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
```
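The `forward` preprocessing itself is unchanged; only the `audio_mask` parameter is dropped. As a toy illustration of what those three context lines do (values invented for the example):

```python
import torch
import torch.nn.functional as F

text = torch.tensor([[4, 7, -1, -1]])  # batch padded with -1, see list_str_to_idx()
seq_len = 6
text = text + 1                        # shift so 0 becomes the filler token
text = text[:, :seq_len]               # curtail if longer than the mel sequence
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # right-pad with filler
print(text)  # tensor([[5, 8, 0, 0, 0, 0]])
```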
```diff
@@ -114,7 +110,7 @@ class TextEmbedding(nn.Module):
 
             text = self.text_blocks(text)
 
         if self.average_upsampling:
-            text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
+            text = self.average_upsample_text_by_mask(text, ~text_mask)
 
         return text
```
```diff
@@ -247,17 +243,16 @@ class DiT(nn.Module):
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
-                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
             else:
                 batch = x.shape[0]
-                seq_lens = audio_mask.sum(dim=1)
+                seq_lens = audio_mask.sum(dim=1)  # Calculate the actual sequence length for each sample
                 text_embed_list = []
                 for i in range(batch):
                     text_embed_i = self.text_embed(
                         text[i].unsqueeze(0),
-                        seq_lens[i].item(),
+                        seq_len=seq_lens[i].item(),
                         drop_text=drop_text,
-                        audio_mask=audio_mask,
                     )
                     text_embed_list.append(text_embed_i[0])
                 text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
```
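With an `audio_mask` present, each sample is now embedded at its own true length and the results are re-padded into one batch tensor. A toy illustration of the `pad_sequence` re-padding step (shapes invented for the example):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two per-sample embeddings of different lengths, each [len_i, dim]:
embeds = [torch.randn(5, 4), torch.randn(3, 4)]
batched = pad_sequence(embeds, batch_first=True, padding_value=0)
print(batched.shape)  # torch.Size([2, 5, 4]); the shorter sample is zero-padded
```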
Tokenizer utilities:

```diff
@@ -7,7 +7,7 @@ import random
 from collections import defaultdict
 from importlib.resources import files
 
-import jieba
+import rjieba
 import torch
 from pypinyin import Style, lazy_pinyin
 from torch.nn.utils.rnn import pad_sequence
```
```diff
@@ -146,10 +146,6 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
 
 
 def convert_char_to_pinyin(text_list, polyphone=True):
-    if jieba.dt.initialized is False:
-        jieba.default_logger.setLevel(50)  # CRITICAL
-        jieba.initialize()
-
     final_text_list = []
     custom_trans = str.maketrans(
         {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
```
```diff
@@ -163,7 +159,7 @@ def convert_char_to_pinyin(text_list, polyphone=True):
     for text in text_list:
         char_list = []
         text = text.translate(custom_trans)
-        for seg in jieba.cut(text):
+        for seg in rjieba.cut(text):
             seg_byte_len = len(bytes(seg, "UTF-8"))
             if seg_byte_len == len(seg):  # if pure alphabets and symbols
                 if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
```
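`rjieba` (the Rust port of jieba) needs no explicit `initialize()` call or logger silencing, which is why the startup block was removed above; `rjieba.cut(text)` slots into the same loop. A hedged sketch of that loop's byte-length heuristic, which the swap leaves untouched (`len(bytes(seg, "UTF-8")) == len(seg)` holds exactly when a segment is pure ASCII):

```python
import rjieba  # note: rjieba.cut returns a list of segment strings,
               # not the generator jieba.cut yields, but both iterate the same

for seg in rjieba.cut("天气 nice"):
    is_ascii = len(bytes(seg, "UTF-8")) == len(seg)
    print(seg, is_ascii)  # True only for pure alphabet/symbol segments
```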
Triton runtime Dockerfile:

```diff
@@ -1,3 +1,3 @@
 FROM nvcr.io/nvidia/tritonserver:24.12-py3
-RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 jieba pypinyin librosa vocos
+RUN pip install tritonclient[grpc] tensorrt-llm==0.16.0 torchaudio==2.5.1 rjieba pypinyin librosa vocos
 WORKDIR /workspace
```
Triton Python backend (model.py):

```diff
@@ -26,7 +26,7 @@
 import json
 import os
 
-import jieba
+import rjieba
 import torch
 import torchaudio
 import triton_python_backend_utils as pb_utils
```
```diff
@@ -66,7 +66,7 @@ def convert_char_to_pinyin(reference_target_texts_list, polyphone=True):
     for text in reference_target_texts_list:
         char_list = []
         text = text.translate(custom_trans)
-        for seg in jieba.cut(text):
+        for seg in rjieba.cut(text):
             seg_byte_len = len(bytes(seg, "UTF-8"))
             if seg_byte_len == len(seg):  # if pure alphabets and symbols
                 if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
```
Dataset preparation script (end-of-file notes):

```diff
@@ -225,5 +225,5 @@ if __name__ == "__main__":
     # bad zh asr cnt 230435 (samples)
     # bad en asr cnt 37217 (samples)
 
-    # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
+    # vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
     # please be careful if using pretrained model, make sure the vocab.txt is same
```
A second dataset preparation script (same vocab note):

```diff
@@ -122,5 +122,5 @@ if __name__ == "__main__":
     # - - 1459 (polyphone)
     # char vocab size 5264 5219 5042
 
-    # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
+    # vocab size may be slightly different due to rjieba tokenizer and pypinyin (e.g. way of polyphoneme)
     # please be careful if using pretrained model, make sure the vocab.txt is same
```