mirror of
https://github.com/SWivid/F5-TTS.git
synced 2026-03-12 21:02:50 -07:00
Use torch.utils.checkpoint in mmdit forward loop when enabled to reduce memory usage.
This commit is contained in:
@@ -96,6 +96,7 @@ class MMDiT(nn.Module):
|
||||
text_num_embeds=256,
|
||||
text_mask_padding=True,
|
||||
qk_norm=None,
|
||||
checkpoint_activations=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -126,6 +127,8 @@ class MMDiT(nn.Module):
|
||||
self.norm_out = AdaLayerNorm_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
|
||||
self.checkpoint_activations = checkpoint_activations
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
@@ -142,6 +145,13 @@ class MMDiT(nn.Module):
|
||||
nn.init.constant_(self.proj_out.weight, 0)
|
||||
nn.init.constant_(self.proj_out.bias, 0)
|
||||
|
||||
def ckpt_wrapper(self, module):
|
||||
def ckpt_forward(*inputs):
|
||||
outputs = module(*inputs)
|
||||
return outputs
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def get_input_embed(
|
||||
self,
|
||||
x, # b n d
|
||||
@@ -205,7 +215,12 @@ class MMDiT(nn.Module):
|
||||
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
|
||||
if self.checkpoint_activations:
|
||||
c, x = torch.utils.checkpoint.checkpoint(
|
||||
self.ckpt_wrapper(block), x, c, t, mask, rope_audio, rope_text, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
|
||||
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x)
|
||||
|
||||
Reference in New Issue
Block a user