forked from mrq/DL-Art-School
add checkpointing
This commit is contained in:
parent
c0db85bf4f
commit
8a1b8e3e62
|
@ -4,6 +4,7 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
from models.arch_util import zero_module
|
from models.arch_util import zero_module
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
class Downsample(nn.Module):
|
||||||
|
@ -42,7 +43,10 @@ class ResBlock(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x) + x
|
return checkpoint(self._forward, x) + x
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user