runtime trtllm: minor fixes. pytorch: update text_embedding logic to correct v0 batching.

SWivid
2025-10-22 00:19:45 +00:00
parent c8bfc3aa3d
commit a0b8fb5df2
3 changed files with 36 additions and 21 deletions


@@ -12,6 +12,7 @@ from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from x_transformers.x_transformers import RotaryEmbedding
from f5_tts.model.modules import (
@@ -236,19 +237,30 @@ class DiT(nn.Module):
cache: bool = True,
audio_mask: bool["b n"] | None = None, # noqa: F722
):
seq_len = x.shape[1]
# TODO. modify to get text_embed one by one (to avoid misalignment when batching), as done in runtime imple.
if self.text_uncond is None or self.text_cond is None or not cache:
batch = x.shape[0]
seq_lens = audio_mask.sum(dim=1)
text_embed_list = []
for i in range(batch):
text_embed_i = self.text_embed(
text[i].unsqueeze(0),
seq_lens[i].item(),
drop_text=drop_text,
audio_mask=audio_mask,
)
text_embed_list.append(text_embed_i[0])
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
if cache:
if drop_text:
self.text_uncond = text_embed
else:
self.text_cond = text_embed
if cache:
if drop_text:
if self.text_uncond is None:
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
text_embed = self.text_uncond
else:
if self.text_cond is None:
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
text_embed = self.text_cond
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
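For context on the DiT change above: instead of embedding the whole batch once at the shared maximum sequence length, each item's text is now embedded at that item's own mel length (taken from audio_mask) and the per-item results are re-batched with pad_sequence, so each text embedding is aligned to its own item rather than to the batch maximum. A minimal sketch of the pattern, with embed_one as a hypothetical stand-in for self.text_embed (names, shapes, and values are illustrative, not the F5-TTS API):

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def embed_one(text_ids, seq_len, dim=4):
    # hypothetical stand-in for self.text_embed: pad/curtail one item's token ids
    # to that item's own mel length, then map them to a (1, seq_len, dim) tensor
    text_ids = F.pad(text_ids[:, :seq_len], (0, max(seq_len - text_ids.shape[1], 0)), value=0)
    return text_ids.float().unsqueeze(-1).expand(-1, -1, dim)

audio_mask = torch.tensor([[True] * 10, [True] * 6 + [False] * 4])  # item lengths 10 and 6
text = torch.tensor([[3, 4, 5, 0, 0], [7, 8, 9, 2, 6]])

seq_lens = audio_mask.sum(dim=1)  # per-item mel lengths, as in the updated forward()
embed_list = [embed_one(text[i].unsqueeze(0), seq_lens[i].item())[0] for i in range(text.shape[0])]
text_embed = pad_sequence(embed_list, batch_first=True, padding_value=0)  # (2, 10, dim)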


@@ -42,13 +42,16 @@ class TextEmbedding(nn.Module):
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
def forward(self, text, seq_len):
def forward(self, text, seq_len, drop_text=False):
text = text + 1
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
if self.mask_padding:
text_mask = text == 0
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
text = text + self.freqs_cis[:seq_len, :]
if self.mask_padding:
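The drop_text flag added to TextEmbedding.forward above zeroes the token ids before the embedding lookup, so the same module can also produce the text-dropped branch used for classifier-free guidance. A rough, self-contained illustration of that path (a toy module, not the runtime TextEmbedding class; the positional encoding and ConvNeXt blocks are omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyTextEmbedding(nn.Module):
    def __init__(self, vocab_size=256, dim=8):
        super().__init__()
        self.text_embed = nn.Embedding(vocab_size + 1, dim)  # index 0 reserved as the filler token

    def forward(self, text, seq_len, drop_text=False):
        text = text + 1                                    # shift ids so 0 becomes the filler token
        text = text[:, :seq_len]                           # curtail if there are more tokens than mel frames
        text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
        if drop_text:                                      # cfg for text: drop the condition entirely
            text = torch.zeros_like(text)
        return self.text_embed(text)                       # (b, seq_len, dim)

embedder = ToyTextEmbedding()
tokens = torch.randint(0, 256, (1, 12))
cond = embedder(tokens, seq_len=32)                        # conditional text embedding
uncond = embedder(tokens, seq_len=32, drop_text=True)      # text-dropped embedding for CFG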
@@ -385,17 +388,17 @@ class F5TTS(object):
# get text_embed one by one to avoid misalignment
text_and_drop_embedding_list = []
for i in range(batch):
text_and_drop_embedding_i = self.text_embedding(
torch.cat(
(
text_pad_sequence[i].unsqueeze(0).to(self.device),
torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
),
dim=0,
),
text_embedding_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=False,
)
text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
text_embedding_drop_i = self.text_embedding(
text_pad_sequence[i].unsqueeze(0).to(self.device),
estimated_reference_target_mel_len[i],
drop_text=True,
)
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
# pad separately computed text_embed to form batch with max_seq_len
text_and_drop_embedding = pad_sequence(

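The runtime hunk above drops the old trick of stacking each item's text with a row of -1s (which, after the +1 shift inside the embedding, turns into an all-filler row serving as the dropped-text condition) and instead calls the embedding twice per item, once with drop_text=False and once with drop_text=True, at that item's estimated reference+target mel length; the per-item outputs are then interleaved and padded into a single batch. A schematic of that list building, again with a hypothetical toy embedder standing in for self.text_embedding:

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def toy_embed(text_ids, seq_len, drop_text=False, dim=4):
    # hypothetical stand-in for self.text_embedding(...) applied to a single item
    ids = torch.zeros_like(text_ids) if drop_text else text_ids
    ids = F.pad(ids[:, :seq_len], (0, max(seq_len - ids.shape[1], 0)), value=0)
    return ids.float().unsqueeze(-1).expand(-1, -1, dim)  # (1, seq_len, dim)

text_pad_sequence = torch.tensor([[3, 4, 5], [7, 8, 9]])
estimated_reference_target_mel_len = [8, 5]  # per-item estimated lengths

text_and_drop_embedding_list = []
for i in range(text_pad_sequence.shape[0]):
    cond_i = toy_embed(text_pad_sequence[i].unsqueeze(0), estimated_reference_target_mel_len[i], drop_text=False)
    drop_i = toy_embed(text_pad_sequence[i].unsqueeze(0), estimated_reference_target_mel_len[i], drop_text=True)
    text_and_drop_embedding_list.extend([cond_i[0], drop_i[0]])  # interleave cond / dropped per item

# pad the separately computed embeddings up to the batch max length: (2 * batch, max_len, dim)
text_and_drop_embedding = pad_sequence(text_and_drop_embedding_list, batch_first=True, padding_value=0)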

@@ -229,7 +229,7 @@ class TritonPythonModel:
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
batch = len(requests)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
for i, mel in enumerate(mel_features_list):
mel_features[i, : mel.shape[1], :] = mel
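The first Triton-model fix above changes the padded batch buffer for the mel features from float16 to float32. The surrounding pattern is a zero-initialized (batch, max_seq_len, n_mel_channels) tensor that each request's mel is copied into up to its own length; roughly (2-D mels here for simplicity, without the per-request leading batch dim used in the runtime):

import torch

n_mel_channels = 100
mel_features_list = [torch.randn(220, n_mel_channels), torch.randn(180, n_mel_channels)]
max_seq_len = max(m.shape[0] for m in mel_features_list)

# zero-padded fp32 batch buffer (the diff changes this dtype from float16 to float32)
mel_features = torch.zeros((len(mel_features_list), max_seq_len, n_mel_channels), dtype=torch.float32)
for i, mel in enumerate(mel_features_list):
    mel_features[i, : mel.shape[0], :] = mel  # copy each item up to its own length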
@@ -254,9 +254,9 @@ class TritonPythonModel:
responses = []
for i in range(batch):
ref_me_len = reference_mel_len[i]
ref_mel_len = reference_mel_len[i]
estimated_mel_len = estimated_reference_target_mel_len[i]
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
audio = self.forward_vocoder(denoised_one_item)
if reference_rms_list[i] < self.target_rms:
audio = audio * reference_rms_list[i] / self.target_rms
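The last hunk renames ref_me_len to ref_mel_len and keeps slicing only the generated frames, from ref_mel_len to estimated_mel_len, out of each denoised mel before vocoding. The RMS rescaling on the final lines matches the usual F5-TTS inference convention: a reference quieter than target_rms is presumably boosted up to target_rms before synthesis, so the generated waveform is scaled back down by the same ratio afterwards. A small sketch of that round trip (plain PyTorch with illustrative values, not the Triton model code):

import torch

target_rms = 0.1
reference = 0.02 * torch.randn(24000)                        # a quiet reference waveform

reference_rms = torch.sqrt(torch.mean(torch.square(reference)))
if reference_rms < target_rms:
    reference = reference * target_rms / reference_rms       # boost the reference before inference

# ... synthesize the target mel, run the vocoder; placeholder output here ...
audio = 0.1 * torch.randn(48000)

if reference_rms < target_rms:
    audio = audio * reference_rms / target_rms               # scale output back to the reference loudness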