mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-03-12 21:02:50 -07:00
Merge pull request #1269 from ZhikangNiu/main
feat:add mmdit flash attn support fix: autocast when use flash_attn to enable log_sample
This commit is contained in:
@@ -97,6 +97,8 @@ class MMDiT(nn.Module):
|
||||
text_mask_padding=True,
|
||||
qk_norm=None,
|
||||
checkpoint_activations=False,
|
||||
attn_backend="torch",
|
||||
attn_mask_enabled=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -120,6 +122,8 @@ class MMDiT(nn.Module):
|
||||
ff_mult=ff_mult,
|
||||
context_pre_only=i == depth - 1,
|
||||
qk_norm=qk_norm,
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
@@ -197,6 +201,7 @@ class MMDiT(nn.Module):
|
||||
|
||||
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
c_mask = (text + 1) != 0 # True = valid, False = padding (-1 tokens)
|
||||
if cfg_infer: # pack cond & uncond forward: b n d -> 2b n d
|
||||
x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
|
||||
x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
|
||||
@@ -204,6 +209,7 @@ class MMDiT(nn.Module):
|
||||
c = torch.cat((c_cond, c_uncond), dim=0)
|
||||
t = torch.cat((t, t), dim=0)
|
||||
mask = torch.cat((mask, mask), dim=0) if mask is not None else None
|
||||
c_mask = torch.cat((c_mask, c_mask), dim=0)
|
||||
else:
|
||||
x, c = self.get_input_embed(
|
||||
x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
|
||||
@@ -217,10 +223,10 @@ class MMDiT(nn.Module):
|
||||
for block in self.transformer_blocks:
|
||||
if self.checkpoint_activations:
|
||||
c, x = torch.utils.checkpoint.checkpoint(
|
||||
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False
|
||||
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
|
||||
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text, c_mask=c_mask)
|
||||
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x)
|
||||
|
||||
@@ -428,9 +428,10 @@ class Attention(nn.Module):
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
c_mask: bool["b nt"] | None = None, # text mask
|
||||
) -> torch.Tensor:
|
||||
if c is not None:
|
||||
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
|
||||
else:
|
||||
return self.processor(self, x, mask=mask, rope=rope)
|
||||
|
||||
@@ -549,8 +550,16 @@ class AttnProcessor:
|
||||
|
||||
|
||||
class JointAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
attn_backend: str = "torch", # "torch" or "flash_attn"
|
||||
attn_mask_enabled: bool = True,
|
||||
):
|
||||
if attn_backend == "flash_attn":
|
||||
assert is_package_available("flash_attn"), "Please install flash-attn first."
|
||||
|
||||
self.attn_backend = attn_backend
|
||||
self.attn_mask_enabled = attn_mask_enabled
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -560,8 +569,10 @@ class JointAttnProcessor:
|
||||
mask: bool["b n"] | None = None,
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
c_mask: bool["b nt"] | None = None, # text mask
|
||||
) -> torch.FloatTensor:
|
||||
residual = x
|
||||
audio_mask = mask
|
||||
|
||||
batch_size = c.shape[0]
|
||||
|
||||
@@ -612,16 +623,48 @@ class JointAttnProcessor:
|
||||
key = torch.cat([key, c_key], dim=2)
|
||||
value = torch.cat([value, c_value], dim=2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
# build combined mask for joint attention: audio mask + text mask
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
if c_mask is not None:
|
||||
mask = torch.cat([mask, c_mask], dim=1)
|
||||
else:
|
||||
mask = F.pad(mask, (0, c.shape[1]), value=True)
|
||||
|
||||
if self.attn_backend == "torch":
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
attn_mask = mask
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
elif self.attn_backend == "flash_attn":
|
||||
query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
if self.attn_mask_enabled and mask is not None:
|
||||
total_seq_len = query.shape[1]
|
||||
query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
|
||||
key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
|
||||
value, _, _, _, _ = unpad_input(value, mask)
|
||||
x = flash_attn_varlen_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_cu_seqlens,
|
||||
k_cu_seqlens,
|
||||
q_max_seqlen_in_batch,
|
||||
k_max_seqlen_in_batch,
|
||||
)
|
||||
x = pad_input(x, indices, batch_size, total_seq_len)
|
||||
x = x.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
else:
|
||||
x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
|
||||
x = x.reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
@@ -637,10 +680,10 @@ class JointAttnProcessor:
|
||||
if not attn.context_pre_only:
|
||||
c = attn.to_out_c(c)
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||
if audio_mask is not None:
|
||||
x = x.masked_fill(~audio_mask.unsqueeze(-1), 0.0)
|
||||
if c_mask is not None:
|
||||
c = c.masked_fill(~c_mask.unsqueeze(-1), 0.0)
|
||||
|
||||
return x, c
|
||||
|
||||
@@ -711,7 +754,17 @@ class MMDiTBlock(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
|
||||
self,
|
||||
dim,
|
||||
heads,
|
||||
dim_head,
|
||||
ff_mult=4,
|
||||
dropout=0.1,
|
||||
context_dim=None,
|
||||
context_pre_only=False,
|
||||
qk_norm=None,
|
||||
attn_backend="torch",
|
||||
attn_mask_enabled=False,
|
||||
):
|
||||
super().__init__()
|
||||
if context_dim is None:
|
||||
@@ -721,7 +774,10 @@ class MMDiTBlock(nn.Module):
|
||||
self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
|
||||
self.attn_norm_x = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
processor=JointAttnProcessor(),
|
||||
processor=JointAttnProcessor(
|
||||
attn_backend=attn_backend,
|
||||
attn_mask_enabled=attn_mask_enabled,
|
||||
),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
@@ -740,7 +796,9 @@ class MMDiTBlock(nn.Module):
|
||||
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
|
||||
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
|
||||
def forward(
|
||||
self, x, c, t, mask=None, rope=None, c_rope=None, c_mask=None
|
||||
): # x: noised input, c: context, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
if self.context_pre_only:
|
||||
norm_c = self.attn_norm_c(c, t)
|
||||
@@ -749,7 +807,7 @@ class MMDiTBlock(nn.Module):
|
||||
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
||||
|
||||
# attention
|
||||
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
|
||||
|
||||
# process attention output for context c
|
||||
if self.context_pre_only:
|
||||
|
||||
@@ -409,7 +409,7 @@ class Trainer:
|
||||
infer_text = [
|
||||
text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
|
||||
]
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(), self.accelerator.autocast():
|
||||
generated, _ = self.accelerator.unwrap_model(self.model).sample(
|
||||
cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
|
||||
text=infer_text,
|
||||
|
||||
Reference in New Issue
Block a user