diff --git a/codes/models/audio/music/diffwave.py b/codes/models/audio/music/diffwave.py
index 348721bb..87d02a92 100644
--- a/codes/models/audio/music/diffwave.py
+++ b/codes/models/audio/music/diffwave.py
@@ -20,6 +20,8 @@ import torch.nn.functional as F
 
 from math import sqrt
 
+from torch.utils.checkpoint import checkpoint
+
 from trainer.networks import register_model
 
 Linear = nn.Linear
@@ -151,13 +153,13 @@ class DiffWave(nn.Module):
         x = self.input_projection(x)
         x = F.relu(x)
 
-        timesteps = self.diffusion_embedding(timesteps)
+        timesteps = checkpoint(self.diffusion_embedding, timesteps)
         if self.spectrogram_upsampler:  # use conditional model
-            spectrogram = self.spectrogram_upsampler(spectrogram)
+            spectrogram = checkpoint(self.spectrogram_upsampler, spectrogram)
 
         skip = None
         for layer in self.residual_layers:
-            x, skip_connection = layer(x, timesteps, spectrogram)
+            x, skip_connection = checkpoint(layer, x, timesteps, spectrogram)
             skip = skip_connection if skip is None else skip_connection + skip
 
         x = skip / sqrt(len(self.residual_layers))