diff --git a/model/backbones/dit.py b/model/backbones/dit.py
index 9ff5351..b8e6dc3 100644
--- a/model/backbones/dit.py
+++ b/model/backbones/dit.py
@@ -45,9 +45,9 @@ class TextEmbedding(nn.Module):
             self.extra_modeling = False
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
-        batch, text_len = text.shape[0], text.shape[1]
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
         text = F.pad(text, (0, seq_len - text_len), value=0)
 
         if drop_text:  # cfg for text
diff --git a/model/backbones/unett.py b/model/backbones/unett.py
index c4ce2c6..ac1d3d3 100644
--- a/model/backbones/unett.py
+++ b/model/backbones/unett.py
@@ -48,9 +48,9 @@ class TextEmbedding(nn.Module):
             self.extra_modeling = False
 
     def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
-        batch, text_len = text.shape[0], text.shape[1]
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
+        batch, text_len = text.shape[0], text.shape[1]
         text = F.pad(text, (0, seq_len - text_len), value=0)
 
         if drop_text:  # cfg for text
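Both hunks make the same change: `text_len` is now measured after the sequence has been curtailed to `seq_len`, so `seq_len - text_len` can no longer go negative when the character tokens outnumber the mel-spectrogram frames. A minimal standalone sketch of the failure mode this reordering avoids (the shapes and values here are hypothetical, not taken from the repo):

```python
import torch
import torch.nn.functional as F

text = torch.randint(1, 10, (2, 12))  # 12 character tokens per item
seq_len = 8                           # only 8 mel frames available

# old order: length taken before truncation
text_len = text.shape[1]              # 12
bad = F.pad(text[:, :seq_len], (0, seq_len - text_len), value=0)
print(bad.shape)                      # torch.Size([2, 4]); the negative pad shrinks the tensor

# new order: length taken after truncation
trimmed = text[:, :seq_len]
text_len = trimmed.shape[1]           # 8
good = F.pad(trimmed, (0, seq_len - text_len), value=0)
print(good.shape)                     # torch.Size([2, 8]); matches seq_len as intended
```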