diff --git a/src/f5_tts/configs/F5TTS_Base_train.yaml b/src/f5_tts/configs/F5TTS_Base_train.yaml
index d757c5b..fe93332 100644
--- a/src/f5_tts/configs/F5TTS_Base_train.yaml
+++ b/src/f5_tts/configs/F5TTS_Base_train.yaml
@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False  # recompute activations to save memory, at the cost of extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
diff --git a/src/f5_tts/configs/F5TTS_Small_train.yaml b/src/f5_tts/configs/F5TTS_Small_train.yaml
index 833c6af..466ee29 100644
--- a/src/f5_tts/configs/F5TTS_Small_train.yaml
+++ b/src/f5_tts/configs/F5TTS_Small_train.yaml
@@ -28,6 +28,7 @@ model:
     ff_mult: 2
     text_dim: 512
     conv_layers: 4
+    checkpoint_activations: False  # recompute activations to save memory, at the cost of extra compute
   mel_spec:
     target_sample_rate: 24000
     n_mel_channels: 100
diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 391752a..472af28 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -105,6 +105,7 @@ class DiT(nn.Module):
         text_dim=None,
         conv_layers=0,
         long_skip_connection=False,
+        checkpoint_activations=False,
     ):
         super().__init__()
 
@@ -127,6 +128,17 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)
 
+        self.checkpoint_activations = checkpoint_activations
+
+    def ckpt_wrapper(self, module):
+        """Code from https://github.com/chuanyangjin/fast-DiT/blob/1a8ecce58f346f877749f2dc67cdb190d295e4dc/models.py#L233-L237"""
+
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(
         self,
         x: float["b n d"],  # noised input audio  # noqa: F722
@@ -152,7 +164,10 @@ class DiT(nn.Module):
         residual = x
 
         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if self.checkpoint_activations:
+                x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+            else:
+                x = block(x, t, mask=mask, rope=rope)
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
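
Note (not part of the patch): torch.utils.checkpoint.checkpoint(fn, *args) runs fn normally on the forward pass but discards its intermediate activations, then recomputes them during the backward pass, trading extra compute for lower peak memory. The ckpt_wrapper indirection exists because checkpoint passes its arguments positionally, so each block is invoked as block(x, t, mask, rope) rather than with keyword arguments. Below is a minimal self-contained sketch of the same pattern, with a hypothetical Block module standing in for a DiT transformer block (Block and its dimensions are illustrative, not from the patch):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class Block(nn.Module):
        """Hypothetical stand-in for a DiT transformer block."""

        def __init__(self, dim):
            super().__init__()
            self.ff = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))

        def forward(self, x):
            return x + self.ff(x)

    blocks = nn.ModuleList(Block(64) for _ in range(8))
    x = torch.randn(2, 100, 64, requires_grad=True)

    for block in blocks:
        # each block's intermediate activations are recomputed during backward
        # instead of being stored; use_reentrant=False is the recommended mode
        # in recent PyTorch versions (the patch uses the default mode)
        x = checkpoint(block, x, use_reentrant=False)

    x.sum().backward()  # gradients still flow; peak memory drops for deep stacks

To enable the feature in training, flip the new flag in the model arch section of the YAML config (checkpoint_activations: True).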