diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py index d304ae35..fa95ad62 100644 --- a/codes/models/audio/music/music_quantizer.py +++ b/codes/models/audio/music/music_quantizer.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from models.arch_util import zero_module from trainer.networks import register_model +from utils.util import checkpoint class Downsample(nn.Module): @@ -42,7 +43,10 @@ class ResBlock(nn.Module): ) def forward(self, x): - return self.net(x) + x + return checkpoint(self._forward, x) + x + + def _forward(self, x): + return self.net(x) class Wav2Vec2GumbelVectorQuantizer(nn.Module):