add checkpointing
This commit is contained in:
parent
ae5f934ea1
commit
3b074aac34
|
@ -20,6 +20,8 @@ import torch.nn.functional as F
|
|||
|
||||
from math import sqrt
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from trainer.networks import register_model
|
||||
|
||||
Linear = nn.Linear
|
||||
|
@ -151,13 +153,13 @@ class DiffWave(nn.Module):
|
|||
x = self.input_projection(x)
|
||||
x = F.relu(x)
|
||||
|
||||
timesteps = self.diffusion_embedding(timesteps)
|
||||
timesteps = checkpoint(self.diffusion_embedding, timesteps)
|
||||
if self.spectrogram_upsampler: # use conditional model
|
||||
spectrogram = self.spectrogram_upsampler(spectrogram)
|
||||
spectrogram = checkpoint(self.spectrogram_upsampler, spectrogram)
|
||||
|
||||
skip = None
|
||||
for layer in self.residual_layers:
|
||||
x, skip_connection = layer(x, timesteps, spectrogram)
|
||||
x, skip_connection = checkpoint(layer, x, timesteps, spectrogram)
|
||||
skip = skip_connection if skip is None else skip_connection + skip
|
||||
|
||||
x = skip / sqrt(len(self.residual_layers))
|
||||
|
|
Loading…
Reference in New Issue
Block a user