diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 0e29b79..4e49b7b 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -51,43 +51,38 @@ class TextEmbedding(nn.Module):
         else:
             self.extra_modeling = False
 
-    def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
+    def average_upsample_text_by_mask(self, text, text_mask):
         batch, text_len, text_dim = text.shape
+        assert batch == 1
 
-        if audio_mask is None:
-            audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
-        valid_mask = audio_mask & text_mask
-        audio_lens = audio_mask.sum(dim=1)  # [batch]
-        valid_lens = valid_mask.sum(dim=1)  # [batch]
+        valid_mask = text_mask[0]
+        audio_len = text_len
+        valid_len = valid_mask.sum().item()
+
+        if valid_len == 0:
+            return torch.zeros_like(text)
 
         upsampled_text = torch.zeros_like(text)
 
-        for i in range(batch):
-            audio_len = audio_lens[i].item()
-            valid_len = valid_lens[i].item()
-
-            if valid_len == 0:
-                continue
-
-            valid_ind = torch.where(valid_mask[i])[0]
-            valid_data = text[i, valid_ind, :]  # [valid_len, text_dim]
-
-            base_repeat = audio_len // valid_len
-            remainder = audio_len % valid_len
-
-            indices = []
-            for j in range(valid_len):
-                repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
-                indices.extend([j] * repeat_count)
-
-            indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
-            upsampled = valid_data[indices]  # [audio_len, text_dim]
-
-            upsampled_text[i, :audio_len, :] = upsampled
+        valid_ind = torch.where(valid_mask)[0]
+        valid_data = text[0, valid_ind, :]  # [valid_len, text_dim]
+
+        base_repeat = audio_len // valid_len
+        remainder = audio_len % valid_len
+
+        indices = []
+        for j in range(valid_len):
+            repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
+            indices.extend([j] * repeat_count)
+
+        indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
+        upsampled = valid_data[indices]  # [audio_len, text_dim]
+
+        upsampled_text[0, :audio_len, :] = upsampled
         return upsampled_text
 
-    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.)
         if not self.average_upsampling:
@@ -114,7 +109,7 @@ class TextEmbedding(nn.Module):
         text = self.text_blocks(text)
 
         if self.average_upsampling:
-            text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
+            text = self.average_upsample_text_by_mask(text, ~text_mask)
 
         return text
 
@@ -247,17 +242,16 @@ class DiT(nn.Module):
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
-                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
             else:
                 batch = x.shape[0]
-                seq_lens = audio_mask.sum(dim=1)
+                seq_lens = audio_mask.sum(dim=1)  # Calculate the actual sequence length for each sample
                 text_embed_list = []
                 for i in range(batch):
                     text_embed_i = self.text_embed(
                         text[i].unsqueeze(0),
-                        seq_lens[i].item(),
+                        seq_len=seq_lens[i].item(),
                         drop_text=drop_text,
-                        audio_mask=audio_mask,
                     )
                     text_embed_list.append(text_embed_i[0])
                 text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
@@ -331,4 +325,4 @@ class DiT(nn.Module):
         x = self.norm_out(x, t)
         output = self.proj_out(x)
 
-        return output
+        return output
\ No newline at end of file
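
Note (reviewer sketch, not part of the patch): the new single-batch path spreads audio_len frames across valid_len text tokens by giving every token base_repeat frames and letting the last `remainder` tokens absorb one extra frame each, so the totals match exactly. A minimal standalone reproduction of that index construction, assuming only PyTorch (the helper name and example values are hypothetical):

    import torch

    def upsample_indices(valid_len: int, audio_len: int) -> torch.Tensor:
        # Each valid token repeats base_repeat times; the last `remainder`
        # tokens get one extra repeat so the total is exactly audio_len.
        base_repeat = audio_len // valid_len
        remainder = audio_len % valid_len
        indices = []
        for j in range(valid_len):
            repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
            indices.extend([j] * repeat_count)
        return torch.tensor(indices[:audio_len], dtype=torch.long)

    print(upsample_indices(3, 8).tolist())  # -> [0, 0, 1, 1, 1, 2, 2, 2]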
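On the DiT side, each sample is now embedded at its true frame count (taken from audio_mask) and the per-sample results are right-padded back into a batch with pad_sequence. A hedged sketch of that call pattern, with `embed` as a stand-in for self.text_embed (simplified signature, synthetic values):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def embed(text: torch.Tensor, seq_len: int) -> torch.Tensor:
        # Stand-in for self.text_embed: one sample in, [1, seq_len, dim] out.
        return torch.randn(1, seq_len, 4)

    audio_mask = torch.tensor([[True, True, True, False],
                               [True, True, True, True]])
    text = torch.randint(0, 10, (2, 4))
    seq_lens = audio_mask.sum(dim=1)  # actual frame count per sample: [3, 4]

    embedded = [embed(text[i].unsqueeze(0), seq_lens[i].item())[0] for i in range(2)]
    batch = pad_sequence(embedded, batch_first=True, padding_value=0)
    print(batch.shape)  # torch.Size([2, 4, 4]); the shorter sample is zero-padded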