From 3b074aac343540deff41237b51d246352f191b10 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 2 May 2022 00:07:42 -0600 Subject: [PATCH] add checkpointing --- codes/models/audio/music/diffwave.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/codes/models/audio/music/diffwave.py b/codes/models/audio/music/diffwave.py index 348721bb..87d02a92 100644 --- a/codes/models/audio/music/diffwave.py +++ b/codes/models/audio/music/diffwave.py @@ -20,6 +20,8 @@ import torch.nn.functional as F from math import sqrt +from torch.utils.checkpoint import checkpoint + from trainer.networks import register_model Linear = nn.Linear @@ -151,13 +153,13 @@ class DiffWave(nn.Module): x = self.input_projection(x) x = F.relu(x) - timesteps = self.diffusion_embedding(timesteps) + timesteps = checkpoint(self.diffusion_embedding, timesteps) if self.spectrogram_upsampler: # use conditional model - spectrogram = self.spectrogram_upsampler(spectrogram) + spectrogram = checkpoint(self.spectrogram_upsampler, spectrogram) skip = None for layer in self.residual_layers: - x, skip_connection = layer(x, timesteps, spectrogram) + x, skip_connection = checkpoint(layer, x, timesteps, spectrogram) skip = skip_connection if skip is None else skip_connection + skip x = skip / sqrt(len(self.residual_layers))