add checkpointing

This commit is contained in:
James Betker 2022-05-31 21:09:05 -06:00
parent c0db85bf4f
commit 8a1b8e3e62

View File

@ -4,6 +4,7 @@ import torch.nn.functional as F
from models.arch_util import zero_module
from trainer.networks import register_model
from utils.util import checkpoint
class Downsample(nn.Module):
@ -42,7 +43,10 @@ class ResBlock(nn.Module):
)
def forward(self, x):
return self.net(x) + x
return checkpoint(self._forward, x) + x
def _forward(self, x):
return self.net(x)
class Wav2Vec2GumbelVectorQuantizer(nn.Module):