diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py index 3019bef..0e5231c 100644 --- a/src/f5_tts/model/backbones/mmdit.py +++ b/src/f5_tts/model/backbones/mmdit.py @@ -96,6 +96,7 @@ class MMDiT(nn.Module): text_num_embeds=256, text_mask_padding=True, qk_norm=None, + checkpoint_activations=False, ): super().__init__() @@ -126,6 +127,8 @@ class MMDiT(nn.Module): self.norm_out = AdaLayerNorm_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) + self.checkpoint_activations = checkpoint_activations + self.initialize_weights() def initialize_weights(self): @@ -142,6 +145,13 @@ class MMDiT(nn.Module): nn.init.constant_(self.proj_out.weight, 0) nn.init.constant_(self.proj_out.bias, 0) + def ckpt_wrapper(self, module): + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward + def get_input_embed( self, x, # b n d @@ -205,7 +215,12 @@ class MMDiT(nn.Module): rope_text = self.rotary_embed.forward_from_seq_len(text_len) for block in self.transformer_blocks: - c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) + if self.checkpoint_activations: + c, x = torch.utils.checkpoint.checkpoint( + self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False + ) + else: + c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text) x = self.norm_out(x, t) output = self.proj_out(x)