diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py
index 0e5231c..373773c 100644
--- a/src/f5_tts/model/backbones/mmdit.py
+++ b/src/f5_tts/model/backbones/mmdit.py
@@ -97,6 +97,8 @@ class MMDiT(nn.Module):
         text_mask_padding=True,
         qk_norm=None,
         checkpoint_activations=False,
+        attn_backend="torch",
+        attn_mask_enabled=False,
     ):
         super().__init__()
 
@@ -120,6 +122,8 @@ class MMDiT(nn.Module):
                     ff_mult=ff_mult,
                     context_pre_only=i == depth - 1,
                     qk_norm=qk_norm,
+                    attn_backend=attn_backend,
+                    attn_mask_enabled=attn_mask_enabled,
                 )
                 for i in range(depth)
             ]
@@ -197,6 +201,7 @@ class MMDiT(nn.Module):
 
         # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
+        c_mask = (text + 1) != 0  # True = valid, False = padding (-1 tokens)
         if cfg_infer:  # pack cond & uncond forward: b n d -> 2b n d
             x_cond, c_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
             x_uncond, c_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
@@ -204,6 +209,7 @@ class MMDiT(nn.Module):
             c = torch.cat((c_cond, c_uncond), dim=0)
             t = torch.cat((t, t), dim=0)
             mask = torch.cat((mask, mask), dim=0) if mask is not None else None
+            c_mask = torch.cat((c_mask, c_mask), dim=0)
         else:
             x, c = self.get_input_embed(
                 x, cond, text, drop_audio_cond=drop_audio_cond, drop_text=drop_text, cache=cache
@@ -217,10 +223,10 @@ class MMDiT(nn.Module):
         for block in self.transformer_blocks:
             if self.checkpoint_activations:
                 c, x = torch.utils.checkpoint.checkpoint(
-                    self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False
+                    self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, c_mask, use_reentrant=False
                 )
             else:
-                c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
+                c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text, c_mask=c_mask)
 
         x = self.norm_out(x, t)
         output = self.proj_out(x)
diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py
index 274c857..348fc30 100644
--- a/src/f5_tts/model/modules.py
+++ b/src/f5_tts/model/modules.py
@@ -428,9 +428,10 @@ class Attention(nn.Module):
         mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
+        c_mask: bool["b nt"] | None = None,  # text mask
     ) -> torch.Tensor:
         if c is not None:
-            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
+            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
         else:
             return self.processor(self, x, mask=mask, rope=rope)
 
@@ -549,8 +550,16 @@ class AttnProcessor:
 
 
 class JointAttnProcessor:
-    def __init__(self):
-        pass
+    def __init__(
+        self,
+        attn_backend: str = "torch",  # "torch" or "flash_attn"
+        attn_mask_enabled: bool = True,
+    ):
+        if attn_backend == "flash_attn":
+            assert is_package_available("flash_attn"), "Please install flash-attn first."
+
+        self.attn_backend = attn_backend
+        self.attn_mask_enabled = attn_mask_enabled
 
     def __call__(
         self,
@@ -560,8 +569,10 @@ class JointAttnProcessor:
         attn: Attention,
         x: float["b n d"],  # noised input x
         c: float["b nt d"] = None,  # context c
         mask: bool["b n"] | None = None,
         rope=None,  # rotary position embedding for x
         c_rope=None,  # rotary position embedding for c
+        c_mask: bool["b nt"] | None = None,  # text mask
     ) -> torch.FloatTensor:
         residual = x
+        audio_mask = mask
 
         batch_size = c.shape[0]
@@ -612,16 +623,48 @@ class JointAttnProcessor:
         key = torch.cat([key, c_key], dim=2)
         value = torch.cat([value, c_value], dim=2)
 
-        # mask. e.g. inference got a batch with different target durations, mask out the padding
-        if mask is not None:
-            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
-            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
-            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
-        else:
-            attn_mask = None
+        # build combined mask for joint attention: audio mask + text mask
+        if self.attn_mask_enabled and mask is not None:
+            if c_mask is not None:
+                mask = torch.cat([mask, c_mask], dim=1)
+            else:
+                mask = F.pad(mask, (0, c.shape[1]), value=True)
+
+        if self.attn_backend == "torch":
+            # mask. e.g. inference got a batch with different target durations, mask out the padding
+            if self.attn_mask_enabled and mask is not None:
+                attn_mask = mask
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
+                attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
+            else:
+                attn_mask = None
+            x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
+            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        elif self.attn_backend == "flash_attn":
+            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            if self.attn_mask_enabled and mask is not None:
+                total_seq_len = query.shape[1]
+                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(query, mask)
+                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
+                value, _, _, _, _ = unpad_input(value, mask)
+                x = flash_attn_varlen_func(
+                    query,
+                    key,
+                    value,
+                    q_cu_seqlens,
+                    k_cu_seqlens,
+                    q_max_seqlen_in_batch,
+                    k_max_seqlen_in_batch,
+                )
+                x = pad_input(x, indices, batch_size, total_seq_len)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
+            else:
+                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
+                x = x.reshape(batch_size, -1, attn.heads * head_dim)
 
-        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
-        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         x = x.to(query.dtype)
 
         # Split the attention outputs.
@@ -637,10 +680,10 @@ class JointAttnProcessor:
 
         if not attn.context_pre_only:
             c = attn.to_out_c(c)
 
-        if mask is not None:
-            mask = mask.unsqueeze(-1)
-            x = x.masked_fill(~mask, 0.0)
-            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)
+        if audio_mask is not None:
+            x = x.masked_fill(~audio_mask.unsqueeze(-1), 0.0)
+        if c_mask is not None:
+            c = c.masked_fill(~c_mask.unsqueeze(-1), 0.0)
 
         return x, c
@@ -711,7 +754,17 @@ class MMDiTBlock(nn.Module):
     """
 
     def __init__(
-        self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_dim=None, context_pre_only=False, qk_norm=None
+        self,
+        dim,
+        heads,
+        dim_head,
+        ff_mult=4,
+        dropout=0.1,
+        context_dim=None,
+        context_pre_only=False,
+        qk_norm=None,
+        attn_backend="torch",
+        attn_mask_enabled=False,
     ):
         super().__init__()
         if context_dim is None:
@@ -721,7 +774,10 @@ class MMDiTBlock(nn.Module):
         self.attn_norm_c = AdaLayerNorm_Final(context_dim) if context_pre_only else AdaLayerNorm(context_dim)
         self.attn_norm_x = AdaLayerNorm(dim)
         self.attn = Attention(
-            processor=JointAttnProcessor(),
+            processor=JointAttnProcessor(
+                attn_backend=attn_backend,
+                attn_mask_enabled=attn_mask_enabled,
+            ),
             dim=dim,
             heads=heads,
             dim_head=dim_head,
@@ -740,7 +796,9 @@ class MMDiTBlock(nn.Module):
         self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
 
-    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
+    def forward(
+        self, x, c, t, mask=None, rope=None, c_rope=None, c_mask=None
+    ):  # x: noised input, c: context, t: time embedding
         # pre-norm & modulation for attention input
         if self.context_pre_only:
             norm_c = self.attn_norm_c(c, t)
@@ -749,7 +807,7 @@ class MMDiTBlock(nn.Module):
         norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
 
         # attention
-        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
+        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope, c_mask=c_mask)
 
         # process attention output for context c
         if self.context_pre_only:
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 5073236..3d1e73e 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -409,7 +409,7 @@ class Trainer:
                         infer_text = [
                             text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
                         ]
-                        with torch.inference_mode():
+                        with torch.inference_mode(), self.accelerator.autocast():
                             generated, _ = self.accelerator.unwrap_model(self.model).sample(
                                 cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
                                 text=infer_text,