From 8a1b8e3e6220f4a607ca9658b4f362fb5dee00c5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 31 May 2022 21:09:05 -0600 Subject: [PATCH] add checkpointing --- codes/models/audio/music/music_quantizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py index d304ae35..fa95ad62 100644 --- a/codes/models/audio/music/music_quantizer.py +++ b/codes/models/audio/music/music_quantizer.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from models.arch_util import zero_module from trainer.networks import register_model +from utils.util import checkpoint class Downsample(nn.Module): @@ -42,7 +43,10 @@ class ResBlock(nn.Module): ) def forward(self, x): - return self.net(x) + x + return checkpoint(self._forward, x) + x + + def _forward(self, x): + return self.net(x) class Wav2Vec2GumbelVectorQuantizer(nn.Module):