forked from mrq/DL-Art-School
add checkpointing
parent ae5f934ea1
commit 3b074aac34
@@ -20,6 +20,8 @@ import torch.nn.functional as F
 
 from math import sqrt
 
+from torch.utils.checkpoint import checkpoint
+
 from trainer.networks import register_model
 
 Linear = nn.Linear
@@ -151,13 +153,13 @@ class DiffWave(nn.Module):
         x = self.input_projection(x)
         x = F.relu(x)
 
-        timesteps = self.diffusion_embedding(timesteps)
+        timesteps = checkpoint(self.diffusion_embedding, timesteps)
         if self.spectrogram_upsampler: # use conditional model
-            spectrogram = self.spectrogram_upsampler(spectrogram)
+            spectrogram = checkpoint(self.spectrogram_upsampler, spectrogram)
 
         skip = None
         for layer in self.residual_layers:
-            x, skip_connection = layer(x, timesteps, spectrogram)
+            x, skip_connection = checkpoint(layer, x, timesteps, spectrogram)
             skip = skip_connection if skip is None else skip_connection + skip
 
         x = skip / sqrt(len(self.residual_layers))
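
For context, a minimal sketch of the gradient-checkpointing pattern applied above (the toy module and shapes are illustrative, not from this repo): torch.utils.checkpoint.checkpoint runs the wrapped callable without storing its intermediate activations and recomputes them during the backward pass, trading extra compute for lower peak memory.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for one residual layer; not from DL-Art-School.
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

y_plain = block(x)             # keeps every intermediate activation for backward
y_ckpt = checkpoint(block, x)  # recomputes those activations during backward instead

y_ckpt.sum().backward()        # gradients match the un-checkpointed path

One caveat: with the reentrant implementation that was the default at the time, checkpoint warns and yields None gradients when none of its tensor inputs require grad, which can matter for segments fed only integer timesteps or raw conditioning.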