add checkpointing

This commit is contained in:
James Betker 2022-05-02 00:07:42 -06:00
parent ae5f934ea1
commit 3b074aac34

View File

@ -20,6 +20,8 @@ import torch.nn.functional as F
from math import sqrt
from torch.utils.checkpoint import checkpoint
from trainer.networks import register_model
Linear = nn.Linear
@ -151,13 +153,13 @@ class DiffWave(nn.Module):
x = self.input_projection(x)
x = F.relu(x)
timesteps = self.diffusion_embedding(timesteps)
timesteps = checkpoint(self.diffusion_embedding, timesteps)
if self.spectrogram_upsampler: # use conditional model
spectrogram = self.spectrogram_upsampler(spectrogram)
spectrogram = checkpoint(self.spectrogram_upsampler, spectrogram)
skip = None
for layer in self.residual_layers:
x, skip_connection = layer(x, timesteps, spectrogram)
x, skip_connection = checkpoint(layer, x, timesteps, spectrogram)
skip = skip_connection if skip is None else skip_connection + skip
x = skip / sqrt(len(self.residual_layers))