Use torch.utils.checkpoint in mmdit forward loop when enabled to reduce memory usage.

ZhikangNiu
2026-02-14 11:05:08 +08:00
parent 655fbca552
commit bb5526fc5b


@@ -96,6 +96,7 @@ class MMDiT(nn.Module):
         text_num_embeds=256,
         text_mask_padding=True,
         qk_norm=None,
+        checkpoint_activations=False,
     ):
         super().__init__()
@@ -126,6 +127,8 @@ class MMDiT(nn.Module):
         self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+        self.checkpoint_activations = checkpoint_activations
+
         self.initialize_weights()
 
     def initialize_weights(self):
@@ -142,6 +145,13 @@ class MMDiT(nn.Module):
         nn.init.constant_(self.proj_out.weight, 0)
         nn.init.constant_(self.proj_out.bias, 0)
 
+    def ckpt_wrapper(self, module):
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def get_input_embed(
         self,
         x,  # b n d
@@ -205,7 +215,12 @@ class MMDiT(nn.Module):
         rope_text = self.rotary_embed.forward_from_seq_len(text_len)
 
         for block in self.transformer_blocks:
-            c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
+            if self.checkpoint_activations:
+                c, x = torch.utils.checkpoint.checkpoint(
+                    self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False
+                )
+            else:
+                c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
 
         x = self.norm_out(x, t)
         output = self.proj_out(x)
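
For context, a minimal standalone sketch of the non-reentrant checkpointing pattern this change relies on. ToyBlock, its dimensions, and the tensor shapes are illustrative assumptions for the sketch, not code from this repository:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    # Hypothetical stand-in for one MMDiT transformer block: takes (x, c)
    # and returns (c, x), mirroring the call convention in the loop above.
    def __init__(self, dim=64):
        super().__init__()
        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)

    def forward(self, x, c):
        return self.proj_c(c), self.proj_x(x)

block = ToyBlock()
x = torch.randn(2, 10, 64, requires_grad=True)
c = torch.randn(2, 10, 64, requires_grad=True)

# Non-reentrant activation checkpointing: the block's intermediate activations
# are dropped during the forward pass and recomputed during backward, trading
# extra compute for lower memory. Arguments are passed positionally, as in the
# diff above.
c_out, x_out = checkpoint(block, x, c, use_reentrant=False)
(c_out.sum() + x_out.sum()).backward()

In training code, the same behaviour is switched on for the real model by constructing MMDiT with checkpoint_activations=True; the forward loop then routes each block call through torch.utils.checkpoint.checkpoint instead of calling the block directly, with use_reentrant=False selecting the non-reentrant implementation that PyTorch recommends.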