diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 9f0cb88..243a0dd 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -13,7 +13,6 @@ from __future__ import annotations
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn.utils.rnn import pad_sequence
 from x_transformers.x_transformers import RotaryEmbedding
 
 from f5_tts.model.modules import (
@@ -51,18 +50,17 @@ class TextEmbedding(nn.Module):
         else:
             self.extra_modeling = False
 
-    def average_upsample_text_by_mask(self, text, text_mask):
-        batch, text_len, text_dim = text.shape
-
-        audio_len = text_len  # cuz text already padded to same length as audio sequence
+    def average_upsample_text_by_mask(self, text, text_mask, target_lens):
+        batch, max_seq_len, text_dim = text.shape
 
         text_lens = text_mask.sum(dim=1)  # [batch]
 
         upsampled_text = torch.zeros_like(text)
 
         for i in range(batch):
-            text_len = text_lens[i].item()
+            text_len = int(text_lens[i].item())
+            audio_len = int(target_lens[i].item())
-            if text_len == 0:
+            if text_len == 0 or audio_len <= 0:
                 continue
 
             valid_ind = torch.where(text_mask[i])[0]
@@ -85,8 +83,21 @@ class TextEmbedding(nn.Module):
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
-        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
-        text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
+        valid_pos_mask = None
+        if torch.is_tensor(seq_len):
+            seq_len = seq_len.to(device=text.device, dtype=torch.long)
+            max_seq_len = int(seq_len.max().item())
+        else:
+            max_seq_len = int(seq_len)
+
+        text = text[:, :max_seq_len]  # curtail if character tokens are more than the mel spec tokens
+        text = F.pad(text, (0, max_seq_len - text.shape[1]), value=0)
+
+        if torch.is_tensor(seq_len):
+            seq_pos = torch.arange(max_seq_len, device=text.device).unsqueeze(0)
+            valid_pos_mask = seq_pos < seq_len.unsqueeze(1)
+            text = text.masked_fill(~valid_pos_mask, 0)
+
         if self.mask_padding:
             text_mask = text == 0
@@ -94,11 +105,17 @@
             text = torch.zeros_like(text)
 
         text = self.text_embed(text)  # b n -> b n d
+        if valid_pos_mask is not None:
+            # Keep short-sample tail strictly zero (equivalent to per-sample pad_sequence(..., 0)).
+            text = text.masked_fill(~valid_pos_mask.unsqueeze(-1), 0.0)
 
         # possible extra modeling
         if self.extra_modeling:
-            # sinus pos emb
-            text = text + self.freqs_cis[:seq_len, :]
+            # sinus pos emb; for variable seq lengths, only add positions within each sample's valid range.
+            freqs = self.freqs_cis[:max_seq_len, :]
+            if valid_pos_mask is not None:
+                freqs = freqs.unsqueeze(0) * valid_pos_mask.unsqueeze(-1).to(freqs.dtype)
+            text = text + freqs
 
             # convnextv2 blocks
             if self.mask_padding:
@@ -110,7 +127,12 @@
                 text = self.text_blocks(text)
 
             if self.average_upsampling:
-                text = self.average_upsample_text_by_mask(text, ~text_mask)
+                if torch.is_tensor(seq_len):
+                    target_lens = seq_len.to(device=text.device, dtype=torch.long)
+                else:
+                    target_lens = torch.full((text.shape[0],), int(seq_len), device=text.device, dtype=torch.long)
+
+                text = self.average_upsample_text_by_mask(text, ~text_mask, target_lens)
 
         return text
 
@@ -243,19 +265,10 @@ class DiT(nn.Module):
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
            if audio_mask is None:
-                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
+                seq_len = x.shape[1]
             else:
-                batch = x.shape[0]
-                seq_lens = audio_mask.sum(dim=1)  # Calculate the actual sequence length for each sample
-                text_embed_list = []
-                for i in range(batch):
-                    text_embed_i = self.text_embed(
-                        text[i].unsqueeze(0),
-                        seq_len=seq_lens[i].item(),
-                        drop_text=drop_text,
-                    )
-                    text_embed_list.append(text_embed_i[0])
-                text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
+                seq_len = audio_mask.sum(dim=1)  # per-sample valid speech length
+            text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text)
             if cache:
                 if drop_text:
                     self.text_uncond = text_embed
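Reviewer note: a minimal, self-contained sketch (not part of the patch; tensor names mirror the diff, values are invented for the demo) of how the tensor-valued seq_len path builds valid_pos_mask and zeroes the padded tail, which is what lets the batched call replace the old per-sample loop + pad_sequence:

    import torch

    batch, max_seq_len = 2, 6
    # per-sample valid speech lengths, as audio_mask.sum(dim=1) would produce
    seq_len = torch.tensor([6, 3])

    # same construction as in TextEmbedding.forward above
    seq_pos = torch.arange(max_seq_len).unsqueeze(0)   # [1, max_seq_len]
    valid_pos_mask = seq_pos < seq_len.unsqueeze(1)    # [batch, max_seq_len]

    text = torch.randint(1, 10, (batch, max_seq_len))  # dummy token ids
    text = text.masked_fill(~valid_pos_mask, 0)        # tail of the short sample stays 0 (filler token)

    print(valid_pos_mask)
    # tensor([[ True,  True,  True,  True,  True,  True],
    #         [ True,  True,  True, False, False, False]])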