From 84469f35380b0b8c528d11479dd412d9b73851a1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 10 Jun 2022 10:50:34 -0600 Subject: [PATCH] get rid of encoder checkpointing --- codes/models/audio/music/music_quantizer2.py | 22 +++++++++++-------- .../audio/music/transformer_diffusion8.py | 7 +++--- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/codes/models/audio/music/music_quantizer2.py b/codes/models/audio/music/music_quantizer2.py index 771d0305..d7de5ad5 100644 --- a/codes/models/audio/music/music_quantizer2.py +++ b/codes/models/audio/music/music_quantizer2.py @@ -45,7 +45,7 @@ class Upsample(nn.Module): class ResBlock(nn.Module): - def __init__(self, chan): + def __init__(self, chan, checkpoint=True): super().__init__() self.net = nn.Sequential( nn.Conv1d(chan, chan, 3, padding = 1), @@ -56,9 +56,13 @@ class ResBlock(nn.Module): nn.SiLU(), zero_module(nn.Conv1d(chan, chan, 3, padding = 1)), ) + self.checkpoint = checkpoint def forward(self, x): - return checkpoint(self._forward, x) + x + if self.checkpoint: + return checkpoint(self._forward, x) + x + else: + return self._forward(x) + x def _forward(self, x): return self.net(x) @@ -165,7 +169,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): class MusicQuantizer2(nn.Module): def __init__(self, inp_channels=256, inner_dim=1024, codevector_dim=1024, down_steps=2, max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, - codebook_size=16, codebook_groups=4, + codebook_size=16, codebook_groups=4, checkpoint=True, # Downsample args: expressive_downsamples=False): super().__init__() @@ -191,14 +195,14 @@ class MusicQuantizer2(nn.Module): self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] + [nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)]) - self.encoder = nn.Sequential(ResBlock(inner_dim[0]), - ResBlock(inner_dim[0]), - ResBlock(inner_dim[0])) + self.encoder = nn.Sequential(ResBlock(inner_dim[0], checkpoint=checkpoint), + ResBlock(inner_dim[0], checkpoint=checkpoint), + ResBlock(inner_dim[0], checkpoint=checkpoint)) self.enc_norm = nn.LayerNorm(inner_dim[0], eps=1e-5) self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim[0], kernel_size=3, padding=1), - ResBlock(inner_dim[0]), - ResBlock(inner_dim[0]), - ResBlock(inner_dim[0])) + ResBlock(inner_dim[0], checkpoint=checkpoint), + ResBlock(inner_dim[0], checkpoint=checkpoint), + ResBlock(inner_dim[0], checkpoint=checkpoint)) self.codes = torch.zeros((3000000,), dtype=torch.long) self.internal_step = 0 diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index 07dd83db..ab45fb0f 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -203,7 +203,7 @@ class TransformerDiffusionWithQuantizer(nn.Module): self.freeze_quantizer_until = freeze_quantizer_until self.diff = TransformerDiffusion(**kwargs) self.quantizer = MusicQuantizer2(inp_channels=kwargs['in_channels'], inner_dim=quantizer_dims, - codevector_dim=quantizer_dims[0], + codevector_dim=quantizer_dims[0], checkpoint=False, codebook_size=256, codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature @@ -219,14 +219,13 @@ class TransformerDiffusionWithQuantizer(nn.Module): ) def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): - quant_grad_enabled = self.internal_step > self.freeze_quantizer_until - mse, diversity_loss, proj = self.quantizer(truth_mel, return_decoder_latent=True) proj = proj.permute(0,2,1) - # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. + quant_grad_enabled = self.internal_step > self.freeze_quantizer_until if not quant_grad_enabled: proj = proj.detach() + # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. unused = 0 for p in self.quantizer.parameters(): unused = unused + p.mean() * 0