From 6b6ce47d2eeabf72982e54fbc41c4674c6e4a363 Mon Sep 17 00:00:00 2001 From: QingyuLiu0521 <2904292256@qq.com> Date: Sun, 15 Feb 2026 21:31:19 -0500 Subject: [PATCH 1/3] Optimize DiT text embedding with batched per-sample seq handling --- src/f5_tts/model/backbones/dit.py | 63 +++++++++++++++++++------------ 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 9f0cb88..2d3f1aa 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -13,7 +13,6 @@ from __future__ import annotations import torch import torch.nn.functional as F from torch import nn -from torch.nn.utils.rnn import pad_sequence from x_transformers.x_transformers import RotaryEmbedding from f5_tts.model.modules import ( @@ -51,18 +50,17 @@ class TextEmbedding(nn.Module): else: self.extra_modeling = False - def average_upsample_text_by_mask(self, text, text_mask): - batch, text_len, text_dim = text.shape - - audio_len = text_len # cuz text already padded to same length as audio sequence + def average_upsample_text_by_mask(self, text, text_mask, target_lens): + batch, max_seq_len, text_dim = text.shape text_lens = text_mask.sum(dim=1) # [batch] upsampled_text = torch.zeros_like(text) for i in range(batch): - text_len = text_lens[i].item() + text_len = int(text_lens[i].item()) + audio_len = int(target_lens[i].item()) - if text_len == 0: + if text_len == 0 or audio_len <= 0: continue valid_ind = torch.where(text_mask[i])[0] @@ -85,8 +83,21 @@ class TextEmbedding(nn.Module): def forward(self, text: int["b nt"], seq_len, drop_text=False): text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx() - text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens - text = F.pad(text, (0, seq_len - text.shape[1]), value=0) # (opt.) if not self.average_upsampling: + valid_pos_mask = None + if torch.is_tensor(seq_len): + seq_len = seq_len.to(device=text.device, dtype=torch.long) + max_seq_len = int(seq_len.max().item()) + else: + max_seq_len = int(seq_len) + + text = text[:, :max_seq_len] # curtail if character tokens are more than the mel spec tokens + text = F.pad(text, (0, max_seq_len - text.shape[1]), value=0) + + if torch.is_tensor(seq_len): + seq_pos = torch.arange(max_seq_len, device=text.device).unsqueeze(0) + valid_pos_mask = seq_pos < seq_len.unsqueeze(1) + text = text.masked_fill(~valid_pos_mask, 0) + if self.mask_padding: text_mask = text == 0 @@ -94,11 +105,17 @@ class TextEmbedding(nn.Module): text = torch.zeros_like(text) text = self.text_embed(text) # b n -> b n d + if valid_pos_mask is not None: + # Keep short-sample tail strictly zero (equivalent to per-sample pad_sequence(..., 0)). + text = text.masked_fill(~valid_pos_mask.unsqueeze(-1), 0.0) # possible extra modeling if self.extra_modeling: - # sinus pos emb - text = text + self.freqs_cis[:seq_len, :] + # sinus pos emb; for variable seq lengths, only add positions within each sample's valid range. + freqs = self.freqs_cis[:max_seq_len, :] + if valid_pos_mask is not None: + freqs = freqs.unsqueeze(0) * valid_pos_mask.unsqueeze(-1).to(freqs.dtype) + text = text + freqs # convnextv2 blocks if self.mask_padding: @@ -110,7 +127,12 @@ class TextEmbedding(nn.Module): text = self.text_blocks(text) if self.average_upsampling: - text = self.average_upsample_text_by_mask(text, ~text_mask) + if torch.is_tensor(seq_len): + target_lens = seq_len.to(device=text.device, dtype=torch.long) + else: + target_lens = torch.full((text.shape[0],), int(seq_len), device=text.device, dtype=torch.long) + + text = self.average_upsample_text_by_mask(text, ~text_mask, target_lens) return text @@ -243,19 +265,10 @@ class DiT(nn.Module): ): if self.text_uncond is None or self.text_cond is None or not cache: if audio_mask is None: - text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text) + seq_lens = x.shape[1] else: - batch = x.shape[0] - seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample - text_embed_list = [] - for i in range(batch): - text_embed_i = self.text_embed( - text[i].unsqueeze(0), - seq_len=seq_lens[i].item(), - drop_text=drop_text, - ) - text_embed_list.append(text_embed_i[0]) - text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0) + seq_lens = audio_mask.sum(dim=1) # per-sample valid speech length + text_embed = self.text_embed(text, seq_lens, drop_text=drop_text) if cache: if drop_text: self.text_uncond = text_embed @@ -326,4 +339,4 @@ class DiT(nn.Module): x = self.norm_out(x, t) output = self.proj_out(x) - return output + return output \ No newline at end of file From 57dc698c16b326a383fcdffcb3ff56a26f2c80bb Mon Sep 17 00:00:00 2001 From: QingyuLiu0521 <2904292256@qq.com> Date: Sun, 15 Feb 2026 21:41:17 -0500 Subject: [PATCH 2/3] Apply ruff formatting --- src/f5_tts/model/backbones/dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 2d3f1aa..5a3e01a 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -339,4 +339,4 @@ class DiT(nn.Module): x = self.norm_out(x, t) output = self.proj_out(x) - return output \ No newline at end of file + return output From c817d6a21d759d2934688d2380acb5b53f00fb61 Mon Sep 17 00:00:00 2001 From: QingyuLiu0521 <2904292256@qq.com> Date: Sun, 15 Feb 2026 23:24:11 -0500 Subject: [PATCH 3/3] Unify seq_len naming in DiT get_input_embed --- src/f5_tts/model/backbones/dit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py index 5a3e01a..243a0dd 100644 --- a/src/f5_tts/model/backbones/dit.py +++ b/src/f5_tts/model/backbones/dit.py @@ -265,10 +265,10 @@ class DiT(nn.Module): ): if self.text_uncond is None or self.text_cond is None or not cache: if audio_mask is None: - seq_lens = x.shape[1] + seq_len = x.shape[1] else: - seq_lens = audio_mask.sum(dim=1) # per-sample valid speech length - text_embed = self.text_embed(text, seq_lens, drop_text=drop_text) + seq_len = audio_mask.sum(dim=1) # per-sample valid speech length + text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text) if cache: if drop_text: self.text_uncond = text_embed