mirror of https://github.com/SWivid/F5-TTS.git (synced 2026-03-12 21:02:50 -07:00)
Merge pull request #1267 from QingyuLiu0521/qyl/pr-dit-only
Optimize DiT text embedding with batched per-sample seq handling
@@ -13,7 +13,6 @@ from __future__ import annotations
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn.utils.rnn import pad_sequence
 from x_transformers.x_transformers import RotaryEmbedding
 
 from f5_tts.model.modules import (
@@ -51,18 +50,17 @@ class TextEmbedding(nn.Module):
         else:
             self.extra_modeling = False
 
-    def average_upsample_text_by_mask(self, text, text_mask):
-        batch, text_len, text_dim = text.shape
-
-        audio_len = text_len  # cuz text already padded to same length as audio sequence
+    def average_upsample_text_by_mask(self, text, text_mask, target_lens):
+        batch, max_seq_len, text_dim = text.shape
         text_lens = text_mask.sum(dim=1)  # [batch]
 
         upsampled_text = torch.zeros_like(text)
 
         for i in range(batch):
-            text_len = text_lens[i].item()
+            text_len = int(text_lens[i].item())
+            audio_len = int(target_lens[i].item())
 
-            if text_len == 0:
+            if text_len == 0 or audio_len <= 0:
                 continue
 
             valid_ind = torch.where(text_mask[i])[0]
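The hunk above only shows the head of the rewritten loop; the upsampling body itself is unchanged context that is not displayed here. For orientation, a minimal standalone sketch of what average upsampling of this kind typically does with the new target_lens value: one sample's valid text embeddings are stretched to that sample's own audio length by repeating each token, with leftover frames given to the earliest tokens. The function name and the exact remainder distribution are illustrative assumptions, not the repository's implementation.

import torch


def average_upsample_one_sample(valid_text: torch.Tensor, audio_len: int) -> torch.Tensor:
    # valid_text: (text_len, dim) embeddings of one sample's non-padding tokens.
    # Stretch them to audio_len frames by repeating each token; the first
    # (audio_len % text_len) tokens get one extra frame so lengths match exactly.
    text_len = valid_text.shape[0]
    repeats = torch.full((text_len,), audio_len // text_len, dtype=torch.long)
    repeats[: audio_len % text_len] += 1
    return valid_text.repeat_interleave(repeats, dim=0)  # (audio_len, dim)


# 3 text tokens spread over 8 audio frames -> repeat counts [3, 3, 2]
print(average_upsample_one_sample(torch.randn(3, 4), 8).shape)  # torch.Size([8, 4])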
@@ -85,8 +83,21 @@ class TextEmbedding(nn.Module):
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
-        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
-        text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
+        valid_pos_mask = None
+        if torch.is_tensor(seq_len):
+            seq_len = seq_len.to(device=text.device, dtype=torch.long)
+            max_seq_len = int(seq_len.max().item())
+        else:
+            max_seq_len = int(seq_len)
+
+        text = text[:, :max_seq_len]  # curtail if character tokens are more than the mel spec tokens
+        text = F.pad(text, (0, max_seq_len - text.shape[1]), value=0)
+
+        if torch.is_tensor(seq_len):
+            seq_pos = torch.arange(max_seq_len, device=text.device).unsqueeze(0)
+            valid_pos_mask = seq_pos < seq_len.unsqueeze(1)
+            text = text.masked_fill(~valid_pos_mask, 0)
+
         if self.mask_padding:
             text_mask = text == 0
 
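The batching trick introduced here is that a per-sample length tensor replaces the old scalar seq_len: positions are compared against each sample's length by broadcasting, and everything past the valid range is zeroed instead of each sample being trimmed on its own. A minimal self-contained sketch of that mask construction, with toy sizes and hypothetical lengths:

import torch

seq_len = torch.tensor([5, 3, 4])  # hypothetical per-sample lengths for a batch of 3
max_seq_len = int(seq_len.max().item())

# valid_pos_mask[b, t] is True iff position t is within sample b's valid range.
seq_pos = torch.arange(max_seq_len).unsqueeze(0)   # (1, max_seq_len)
valid_pos_mask = seq_pos < seq_len.unsqueeze(1)    # (batch, max_seq_len)

# Zero token ids past each sample's length, as the new forward() does.
text = torch.randint(1, 10, (3, max_seq_len))
text = text.masked_fill(~valid_pos_mask, 0)
print(valid_pos_mask.int())
print(text)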
@@ -94,11 +105,17 @@ class TextEmbedding(nn.Module):
             text = torch.zeros_like(text)
 
         text = self.text_embed(text)  # b n -> b n d
+        if valid_pos_mask is not None:
+            # Keep short-sample tail strictly zero (equivalent to per-sample pad_sequence(..., 0)).
+            text = text.masked_fill(~valid_pos_mask.unsqueeze(-1), 0.0)
 
         # possible extra modeling
         if self.extra_modeling:
-            # sinus pos emb
-            text = text + self.freqs_cis[:seq_len, :]
+            # sinus pos emb; for variable seq lengths, only add positions within each sample's valid range.
+            freqs = self.freqs_cis[:max_seq_len, :]
+            if valid_pos_mask is not None:
+                freqs = freqs.unsqueeze(0) * valid_pos_mask.unsqueeze(-1).to(freqs.dtype)
+            text = text + freqs
 
             # convnextv2 blocks
             if self.mask_padding:
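Masking the shared sinusoidal table with valid_pos_mask is meant to match what slicing self.freqs_cis up to each sample's own length did before: positions past a sample's valid range receive no positional offset. A small numerical check of that equivalence, using a random stand-in table rather than the model's real freqs_cis:

import torch

max_seq_len, dim = 6, 4
freqs_cis = torch.randn(max_seq_len, dim)  # stand-in for the precomputed sinusoidal table
seq_len = torch.tensor([6, 3])             # per-sample valid lengths
text = torch.randn(2, max_seq_len, dim)    # stand-in for the token embeddings

valid_pos_mask = torch.arange(max_seq_len).unsqueeze(0) < seq_len.unsqueeze(1)

# Batched form from the diff: mask the shared table, add once for the whole batch.
freqs = freqs_cis.unsqueeze(0) * valid_pos_mask.unsqueeze(-1).to(freqs_cis.dtype)
batched = text + freqs

# Per-sample reference: add freqs_cis[:n] to the first n positions of each sample.
reference = text.clone()
for i, n in enumerate(seq_len.tolist()):
    reference[i, :n] += freqs_cis[:n]

print(torch.allclose(batched, reference))  # True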
@@ -110,7 +127,12 @@ class TextEmbedding(nn.Module):
             text = self.text_blocks(text)
 
         if self.average_upsampling:
-            text = self.average_upsample_text_by_mask(text, ~text_mask)
+            if torch.is_tensor(seq_len):
+                target_lens = seq_len.to(device=text.device, dtype=torch.long)
+            else:
+                target_lens = torch.full((text.shape[0],), int(seq_len), device=text.device, dtype=torch.long)
+
+            text = self.average_upsample_text_by_mask(text, ~text_mask, target_lens)
 
         return text
 
@@ -243,19 +265,10 @@ class DiT(nn.Module):
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
-                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
+                seq_len = x.shape[1]
             else:
-                batch = x.shape[0]
-                seq_lens = audio_mask.sum(dim=1)  # Calculate the actual sequence length for each sample
-                text_embed_list = []
-                for i in range(batch):
-                    text_embed_i = self.text_embed(
-                        text[i].unsqueeze(0),
-                        seq_len=seq_lens[i].item(),
-                        drop_text=drop_text,
-                    )
-                    text_embed_list.append(text_embed_i[0])
-                text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
+                seq_len = audio_mask.sum(dim=1)  # per-sample valid speech length
+            text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text)
             if cache:
                 if drop_text:
                     self.text_uncond = text_embed
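At the DiT level, the per-sample loop plus pad_sequence is collapsed into one call with a length tensor. The sketch below mirrors that refactor with a toy per-position embedding (not the real TextEmbedding, whose constructor is not shown in this diff): embedding each sample at its own length and re-padding gives the same result as one batched pass with the tail masked to zero. The convolutional blocks in the real module additionally rely on the mask_padding path shown in the diff above.

import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
embed = torch.nn.Embedding(10, 4)      # toy stand-in for a per-position text embedding
text = torch.randint(1, 10, (3, 7))    # padded token ids for a batch of 3
seq_lens = torch.tensor([7, 4, 6])     # per-sample valid lengths, like audio_mask.sum(dim=1)

# Removed style: embed each sample at its own length, then pad the results back together.
per_sample = pad_sequence(
    [embed(text[i, : int(seq_lens[i])]) for i in range(3)],
    batch_first=True,
    padding_value=0,
)

# New style: one batched pass, with positions past each sample's length zeroed by a mask.
mask = torch.arange(text.shape[1]).unsqueeze(0) < seq_lens.unsqueeze(1)
batched = embed(text).masked_fill(~mask.unsqueeze(-1), 0.0)

print(torch.allclose(per_sample, batched))  # True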