pytorch implementation: fix batch inference skipping last words in shorter sentences (issues #1039, #1179)

Author: SWivid
Date: 2025-10-24 05:50:25 +00:00
parent f2a4f8581f
commit a051a68552
3 changed files with 29 additions and 14 deletions

View File

@@ -6,6 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
+# ruff: noqa: F722 F821
 from __future__ import annotations
@@ -86,7 +87,7 @@ class TextEmbedding(nn.Module):
         return upsampled_text

-    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):  # noqa: F722
+    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
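Note on the filler-token convention above: list_str_to_idx() batch-pads text with -1, and the +1 shift makes 0 the filler, so batch padding and trailing frames share one embedding index. A standalone sketch of the same three steps (illustrative values, not repo code):

import torch
import torch.nn.functional as F

text = torch.tensor([[5, 9, 2], [7, -1, -1]])  # batch padded with -1, as list_str_to_idx() would
seq_len = 6                                    # number of mel frames to cover
text = text + 1                                # pad -1 becomes filler token 0
text = text[:, :seq_len]                       # curtail if more tokens than frames
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
print(text)  # tensor([[ 6, 10,  3,  0,  0,  0], [ 8,  0,  0,  0,  0,  0]])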
@@ -127,12 +128,19 @@ class InputEmbedding(nn.Module):
         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
         self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

-    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
+    def forward(
+        self,
+        x: float["b n d"],
+        cond: float["b n d"],
+        text_embed: float["b n d"],
+        drop_audio_cond=False,
+        audio_mask: bool["b n"] | None = None,
+    ):
         if drop_audio_cond:  # cfg for cond audio
             cond = torch.zeros_like(cond)

         x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
-        x = self.conv_pos_embed(x) + x
+        x = self.conv_pos_embed(x, mask=audio_mask) + x
         return x
@@ -235,7 +243,7 @@ class DiT(nn.Module):
         drop_audio_cond: bool = False,
         drop_text: bool = False,
         cache: bool = True,
-        audio_mask: bool["b n"] | None = None,  # noqa: F722
+        audio_mask: bool["b n"] | None = None,
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
             if audio_mask is None:
@@ -265,7 +273,7 @@ class DiT(nn.Module):
         else:
             text_embed = self.text_cond

-        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
+        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond, audio_mask=audio_mask)

         return x
@@ -274,11 +282,11 @@ class DiT(nn.Module):
     def forward(
         self,
-        x: float["b n d"],  # noised input audio  # noqa: F722
-        cond: float["b n d"],  # masked cond audio  # noqa: F722
-        text: int["b nt"],  # text  # noqa: F722
-        time: float["b"] | float[""],  # time step  # noqa: F821 F722
-        mask: bool["b n"] | None = None,  # noqa: F722
+        x: float["b n d"],  # noised input audio
+        cond: float["b n d"],  # masked cond audio
+        text: int["b nt"],  # text
+        time: float["b"] | float[""],  # time step
+        mask: bool["b n"] | None = None,
         drop_audio_cond: bool = False,  # cfg for cond audio
         drop_text: bool = False,  # cfg for text
         cfg_infer: bool = False,  # cfg inference, pack cond & uncond forward
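Usage sketch for the batched path this signature serves: build a bool["b n"] mask from per-sample mel lengths and pass it to forward, so each shorter sentence keeps its true length instead of inheriting the batch maximum. lens_to_mask here is a local helper written for the sketch, not necessarily the repo's:

import torch

def lens_to_mask(lens: torch.Tensor, length: int | None = None) -> torch.Tensor:
    # True on real frames, False on batch padding
    length = length if length is not None else int(lens.max())
    return torch.arange(length, device=lens.device)[None, :] < lens[:, None]

lens = torch.tensor([412, 388, 250])  # mel frames per sentence in the batch
mask = lens_to_mask(lens)             # bool["b n"]
# pred = model(x, cond, text, time, mask=mask, cfg_infer=True)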

View File

@@ -6,7 +6,7 @@ nt - text sequence
 nw - raw wave length
 d - dimension
 """
-# flake8: noqa
+# ruff: noqa: F722 F821
 from __future__ import annotations
@@ -177,14 +177,19 @@ class ConvPositionEmbedding(nn.Module):
             nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
             nn.Mish(),
         )
+        self.layer_need_mask_idx = [i for i, layer in enumerate(self.conv1d) if isinstance(layer, nn.Conv1d)]

     def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
         if mask is not None:
-            mask = mask[..., None]
+            mask = mask.unsqueeze(-1)
+            mask_t = mask.permute(0, 2, 1)
             x = x.masked_fill(~mask, 0.0)

         x = x.permute(0, 2, 1)
-        x = self.conv1d(x)
+        for i, block in enumerate(self.conv1d):
+            x = block(x)
+            if mask is not None and i in self.layer_need_mask_idx:
+                x = x.masked_fill(~mask_t, 0.0)
         out = x.permute(0, 2, 1)

         if mask is not None:
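The hunk above is the heart of the fix: nn.Sequential previously ran the convs straight through, and because each Conv1d has a bias and a kernel_size // 2 halo, padded frames picked up nonzero values that the next conv folded back into the last valid frames of shorter sentences, which is how their final words got corrupted. A toy reproduction (standalone sketch, Mish omitted, channel-first throughout):

import torch
import torch.nn as nn

torch.manual_seed(0)
conv1 = nn.Conv1d(4, 4, 3, padding=1)  # biased convs, like the real stack
conv2 = nn.Conv1d(4, 4, 3, padding=1)

x = torch.randn(1, 4, 6)                                 # one sentence, 6 valid frames
x_padded = torch.cat([x, torch.zeros(1, 4, 4)], dim=-1)  # batch-padded to 10
mask_t = torch.arange(10)[None, None, :] < 6             # (b, 1, n), True on valid frames

# old behaviour: mask only at the ends of the stack
y_naive = conv2(conv1(x_padded)).masked_fill(~mask_t, 0.0)

# fixed behaviour: re-zero padded frames after every Conv1d
h = conv1(x_padded).masked_fill(~mask_t, 0.0)
y_fixed = conv2(h).masked_fill(~mask_t, 0.0)

y_ref = conv2(conv1(x))  # the same sentence run unbatched
print(torch.allclose(y_naive[..., :6], y_ref, atol=1e-6))  # False: tail frames drift
print(torch.allclose(y_fixed[..., :6], y_ref, atol=1e-6))  # True: matches unbatched run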
@@ -435,8 +440,8 @@ class Attention(nn.Module):
 # Attention processor

 if is_package_available("flash_attn"):
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn import flash_attn_varlen_func, flash_attn_func

 class AttnProcessor:
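These imports feed the varlen path: instead of attending over padded frames, valid frames are packed into one ragged batch described by cumulative lengths. A hedged sketch of that call pattern (flash_attn 2.x signatures assumed; needs a CUDA device with flash-attn installed):

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

b, n, h, d = 3, 412, 16, 64
lens = torch.tensor([412, 388, 250], device="cuda")
mask = torch.arange(n, device="cuda")[None, :] < lens[:, None]

q = torch.randn(b, n, h, d, device="cuda", dtype=torch.float16)
q_packed = q[mask]  # (total_valid_frames, h, d), padding dropped
cu_seqlens = F.pad(lens.cumsum(0), (1, 0)).to(torch.int32)  # [0, 412, 800, 1050]
out = flash_attn_varlen_func(
    q_packed, q_packed, q_packed,  # self-attention over packed frames
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(lens.max()), max_seqlen_k=int(lens.max()),
)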

View File

@@ -307,6 +307,8 @@ def main():
         text_mask_padding=pretrained_config["text_mask_padding"],
         conv_layers=pretrained_config["conv_layers"],
         pe_attn_head=pretrained_config["pe_attn_head"],
+        # attn_backend="flash_attn",  # torch | flash_attn
+        # attn_mask_enabled=True,
     )
     model = load_model(DiT, pt_model_config, args.model_path)
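To exercise the masked attention path from this script, the two commented switches can be enabled. A minimal sketch of just those keys (names taken from the diff; treating "torch" as the default backend is an assumption):

extra_attn_cfg = dict(
    attn_backend="flash_attn",  # or "torch" (assumed default)
    attn_mask_enabled=True,     # apply the bool["b n"] mask during attention
)
# pt_model_config.update(extra_attn_cfg) before calling load_model(DiT, ...)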