mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-03-12 21:02:50 -07:00
Unify seq_len naming in DiT get_input_embed
This commit is contained in:
@@ -265,10 +265,10 @@ class DiT(nn.Module):
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
-                seq_lens = x.shape[1]
+                seq_len = x.shape[1]
             else:
-                seq_lens = audio_mask.sum(dim=1) # per-sample valid speech length
-            text_embed = self.text_embed(text, seq_lens, drop_text=drop_text)
+                seq_len = audio_mask.sum(dim=1) # per-sample valid speech length
+            text_embed = self.text_embed(text, seq_len=seq_len, drop_text=drop_text)
             if cache:
                 if drop_text:
                     self.text_uncond = text_embed
|
||||
Reference in New Issue
Block a user