Merge pull request #1269 from ZhikangNiu/main

feat:add mmdit flash attn support
fix: autocast when use flash_attn to enable log_sample
This commit is contained in:
Yushen CHEN
2026-02-27 01:20:09 +08:00
committed by GitHub
3 changed files with 87 additions and 23 deletions

View File

@@ -97,6 +97,8 @@ class MMDiT(nn.Module):
text_mask_padding=True,
qk_norm=None,
checkpoint_activations=False,
attn_backend="torch",
attn_mask_enabled=False,
):
super().__init__()
@@ -120,6 +122,8 @@ class MMDiT(nn.Module):
ff_mult=ff_mult,
context_pre_only=i == depth - 1,
qk_norm=qk_norm,
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
)
for i in range(depth)
]
@@ -197,6 +201,7 @@ class MMDiT(nn.Module):
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
c_mask = (text + 1) != 0 # True = valid, False = padding (-1 tokens)
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
@@ -204,6 +209,7 @@ class MMDiT(nn.Module):
c = torch.cat((c_cond, c_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
c_mask = torch.cat((c_mask, c_mask), dim=0)
else:
x, c = self.get_input_embed(
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
@@ -217,10 +223,10 @@ class MMDiT(nn.Module):
for block in self.transformer_blocks:
if self.checkpoint_activations:
c, x = torch.utils.checkpoint.checkpoint(
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False
)
else:
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text, c_mask=c_mask)
x = self.norm_out(x, t)
output = self.proj_out(x)

View File

@@ -428,9 +428,10 @@ class Attention(nn.Module):
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
c_mask: bool["b nt"] | None = None, # text mask
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
else:
return self.processor(self, x, mask=mask, rope=rope)
@@ -549,8 +550,16 @@ class AttnProcessor:
class JointAttnProcessor:
def __init__(self):
pass
def __init__(
self,
attn_backend: str = "torch", # "torch" or "flash_attn"
attn_mask_enabled: bool = True,
):
if attn_backend == "flash_attn":
assert is_package_available("flash_attn"), "Please install flash-attn first."
self.attn_backend = attn_backend
self.attn_mask_enabled = attn_mask_enabled
def __call__(
self,
@@ -560,8 +569,10 @@ class JointAttnProcessor:
mask: bool["b n"] | None = None,
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
c_mask: bool["b nt"] | None = None, # text mask
) -> torch.FloatTensor:
residual = x
audio_mask = mask
batch_size = c.shape[0]
@@ -612,16 +623,48 @@ class JointAttnProcessor:
key = torch.cat([key, c_key], dim=2)
value = torch.cat([value, c_value], dim=2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
# build combined mask for joint attention: audio mask + text mask
if self.attn_mask_enabled and mask is not None:
if c_mask is not None:
mask = torch.cat([mask, c_mask], dim=1)
else:
mask = F.pad(mask, (0, c.shape[1]), value=True)
if self.attn_backend == "torch":
# mask. e.g. inference got a batch with different target durations, mask out the padding
if self.attn_mask_enabled and mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif self.attn_backend == "flash_attn":
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
key = key.transpose(1, 2)
value = value.transpose(1, 2)
if self.attn_mask_enabled and mask is not None:
total_seq_len = query.shape[1]
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
value, _, _, _, _ = unpad_input(value, mask)
x = flash_attn_varlen_func(
query,
key,
value,
q_cu_seqlens,
k_cu_seqlens,
q_max_seqlen_in_batch,
k_max_seqlen_in_batch,
)
x = pad_input(x, indices, batch_size, total_seq_len)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
else:
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
x = x.reshape(batch_size, -1, attn.heads * head_dim)
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
@@ -637,10 +680,10 @@ class JointAttnProcessor:
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
if audio_mask is not None:
x = x.masked_fill(~audio_mask.unsqueeze(-1), 0.0)
if c_mask is not None:
c = c.masked_fill(~c_mask.unsqueeze(-1), 0.0)
return x, c
@@ -711,7 +754,17 @@ class MMDiTBlock(nn.Module):
"""
def __init__(
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
self,
dim,
heads,
dim_head,
ff_mult=4,
dropout=0.1,
context_dim=None,
context_pre_only=False,
qk_norm=None,
attn_backend="torch",
attn_mask_enabled=False,
):
super().__init__()
if context_dim is None:
@@ -721,7 +774,10 @@ class MMDiTBlock(nn.Module):
self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
self.attn_norm_x = AdaLayerNorm(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
processor=JointAttnProcessor(
attn_backend=attn_backend,
attn_mask_enabled=attn_mask_enabled,
),
dim=dim,
heads=heads,
dim_head=dim_head,
@@ -740,7 +796,9 @@ class MMDiTBlock(nn.Module):
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
def forward(
self, x, c, t, mask=None, rope=None, c_rope=None, c_mask=None
): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
@@ -749,7 +807,7 @@ class MMDiTBlock(nn.Module):
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
# process attention output for context c
if self.context_pre_only:

View File

@@ -409,7 +409,7 @@ class Trainer:
infer_text = [
text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
]
with torch.inference_mode():
with torch.inference_mode(), self.accelerator.autocast():
generated, _ = self.accelerator.unwrap_model(self.model).sample(
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
text=infer_text,