Mirror of https://github.com/SWivid/F5-TTS.git (synced 2025-12-05 20:40:12 -08:00)
runtime trtllm: minor fixes. pytorch: update text_embedding logic to correct v0 batching.
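The pytorch-side change targets a batching bug in the v0 path: when a batch mixes utterances of different lengths, computing text_embed once at the batch-wide max sequence length pads every item's text out to that shared length before positional terms are applied, so shorter items end up with text that no longer lines up with their own mel frames. The fix embeds each item at its own length and right-pads the results back into a batch. A minimal sketch of the idea, with text_embed_fn standing in for the module's per-item embedding call (illustrative names, not the repo's exact API):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def embed_per_item(text, seq_lens, text_embed_fn):
        # text: (b, n_text) token ids; seq_lens: per-item mel lengths
        embeds = []
        for i in range(text.shape[0]):
            # embed item i at its OWN length, not the batch max
            e = text_embed_fn(text[i].unsqueeze(0), int(seq_lens[i]))  # (1, len_i, d)
            embeds.append(e[0])
        # right-pad back to (b, max(seq_lens), d) for the batched transformer
        return pad_sequence(embeds, batch_first=True, padding_value=0)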
@@ -12,6 +12,7 @@ from __future__ import annotations
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.nn.utils.rnn import pad_sequence
 from x_transformers.x_transformers import RotaryEmbedding

 from f5_tts.model.modules import (
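The new pad_sequence import supports reassembling the per-item embeddings computed below. For reference, it right-pads a list of (len_i, d) tensors to the longest length in the list:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    a, b = torch.ones(3, 2), torch.ones(5, 2)
    out = pad_sequence([a, b], batch_first=True, padding_value=0)
    print(out.shape)  # torch.Size([2, 5, 2]); the shorter item is zero-padded at the end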
@@ -236,19 +237,30 @@ class DiT(nn.Module):
         cache: bool = True,
         audio_mask: bool["b n"] | None = None, # noqa: F722
     ):
         seq_len = x.shape[1]
-        # TODO. modify to get text_embed one by one (to avoid misalignment when batching), as done in runtime imple.
+        if self.text_uncond is None or self.text_cond is None or not cache:
+            batch = x.shape[0]
+            seq_lens = audio_mask.sum(dim=1)
+            text_embed_list = []
+            for i in range(batch):
+                text_embed_i = self.text_embed(
+                    text[i].unsqueeze(0),
+                    seq_lens[i].item(),
+                    drop_text=drop_text,
+                    audio_mask=audio_mask,
+                )
+                text_embed_list.append(text_embed_i[0])
+            text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
+            if cache:
+                if drop_text:
+                    self.text_uncond = text_embed
+                else:
+                    self.text_cond = text_embed
         if cache:
             if drop_text:
-                if self.text_uncond is None:
-                    self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
                 text_embed = self.text_uncond
             else:
-                if self.text_cond is None:
-                    self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
                 text_embed = self.text_cond
-        else:
-            text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)

         x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

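With this change, the per-item loop runs only on the first function evaluation of each CFG branch; with cache=True, every later ODE solver step reuses the stored self.text_cond / self.text_uncond, which is safe because the text is fixed for the whole sampling trajectory. The cache then has to be reset between utterances; a sketch of the reset, assuming a clear_cache-style helper like the one the backbone already exposes:

    def clear_cache(self):
        # drop stale per-utterance embeddings so the next call recomputes them
        self.text_cond, self.text_uncond = None, None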
@@ -42,13 +42,16 @@ class TextEmbedding(nn.Module):
         self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
         self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])

-    def forward(self, text, seq_len):
+    def forward(self, text, seq_len, drop_text=False):
         text = text + 1
         text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
         if self.mask_padding:
             text_mask = text == 0

+        if drop_text: # cfg for text
+            text = torch.zeros_like(text)
+
         text = self.text_embed(text) # b n -> b n d
         text = text + self.freqs_cis[:seq_len, :]
         if self.mask_padding:
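The added drop_text flag implements the text half of classifier-free guidance: zeroing all tokens (index 0 is the filler/pad embedding after the text = text + 1 shift) produces the unconditional branch. At sampling time the two branches are blended; a sketch of the blend in the pred + (pred - null_pred) * cfg_strength form that flow-matching CFG commonly uses (names illustrative):

    pred = model(x, text, drop_text=False)       # conditional prediction
    null_pred = model(x, text, drop_text=True)   # text-dropped prediction
    guided = pred + (pred - null_pred) * cfg_strength  # cfg_strength = 0 recovers pred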
@@ -385,17 +388,17 @@ class F5TTS(object):
         # get text_embed one by one to avoid misalignment
         text_and_drop_embedding_list = []
         for i in range(batch):
-            text_and_drop_embedding_i = self.text_embedding(
-                torch.cat(
-                    (
-                        text_pad_sequence[i].unsqueeze(0).to(self.device),
-                        torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
-                    ),
-                    dim=0,
-                ),
-                estimated_reference_target_mel_len[i],
-            )
-            text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
+            text_embedding_i = self.text_embedding(
+                text_pad_sequence[i].unsqueeze(0).to(self.device),
+                estimated_reference_target_mel_len[i],
+                drop_text=False,
+            )
+            text_embedding_drop_i = self.text_embedding(
+                text_pad_sequence[i].unsqueeze(0).to(self.device),
+                estimated_reference_target_mel_len[i],
+                drop_text=True,
+            )
+            text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])

         # pad separately computed text_embed to form batch with max_seq_len
         text_and_drop_embedding = pad_sequence(
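The runtime previously embedded each item as a two-row mini-batch: the real token row concatenated with a row of -1s standing in for dropped text (the +1 shift maps -1 to the filler index 0). Splitting that into two explicit calls with drop_text=False / drop_text=True puts cond and uncond on the same per-item path as the pytorch model and stops relying on the -1 fill. The resulting list interleaves [cond_0, uncond_0, cond_1, uncond_1, ...] before padding; a sketch of how such an interleaved batch would typically be split apart again downstream (illustrative, not the repo's exact consumer code):

    padded = pad_sequence(text_and_drop_embedding_list, batch_first=True)  # (2b, n, d)
    cond_embed = padded[0::2]    # even rows: conditional embeddings
    uncond_embed = padded[1::2]  # odd rows: text-dropped embeddings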
@@ -229,7 +229,7 @@ class TritonPythonModel:
         max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)

         batch = len(requests)
-        mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
+        mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
         for i, mel in enumerate(mel_features_list):
             mel_features[i, : mel.shape[1], :] = mel

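A one-line dtype fix: the mel scratch buffer is now allocated in float32. Slice-assigning float32 mels into a float16 buffer would silently downcast them (PyTorch casts on copy), so float32 here presumably preserves the precision of the incoming features; a cheap guard against reintroducing such a mismatch:

    assert mel.dtype == mel_features.dtype, f"dtype mismatch: {mel.dtype} vs {mel_features.dtype}"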
@@ -254,9 +254,9 @@ class TritonPythonModel:

         responses = []
         for i in range(batch):
-            ref_me_len = reference_mel_len[i]
+            ref_mel_len = reference_mel_len[i]
             estimated_mel_len = estimated_reference_target_mel_len[i]
-            denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
+            denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
             audio = self.forward_vocoder(denoised_one_item)
             if reference_rms_list[i] < self.target_rms:
                 audio = audio * reference_rms_list[i] / self.target_rms
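The last hunk repairs a typo (ref_me_len -> ref_mel_len) so the slice that strips the reference-prompt frames from the denoised mel is named consistently; the slice keeps only the generated span [ref_mel_len:estimated_mel_len] before vocoding. The RMS rescaling that follows undoes the loudness normalization applied to quiet references: if the reference was boosted up to target_rms before inference, the output is scaled back down by the same ratio. The RMS in question is the usual root-mean-square level, computed along these lines:

    reference_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))  # scalar loudness estimate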