Unify seq_len naming in DiT get_input_embed

This commit is contained in:
QingyuLiu0521
2026-02-15 23:24:11 -05:00
parent 57dc698c16
commit c817d6a21d

View File

@@ -265,10 +265,10 @@ class DiT(nn.Module):
):
if self.text_uncond is None or self.text_cond is None or not cache:
if audio_mask is None:
-seq_lens = x.shape[1]
+seq_len = x.shape[1]
 else:
-seq_lens = audio_mask.sum(dim=1) # per-sample valid speech length
-text_embed = self.text_embed(text, seq_lens, drop_text=drop_text)
+seq_len = audio_mask.sum(dim=1) # per-sample valid speech length
+text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text)
if cache:
if drop_text:
self.text_uncond = text_embed