From 691ed196daadb01797c53383f779349834f18164 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 19 Jun 2022 21:29:40 -0600 Subject: [PATCH] fix codes --- .../audio/music/transformer_diffusion12.py | 3 +++ .../audio/music/unet_diffusion_waveform_gen3.py | 17 +++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 208b4c6f..6ebbe8aa 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -596,14 +596,17 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt): def register_transformer_diffusion12_with_ar_prior(opt_net, opt): return TransformerDiffusionWithARPrior(**opt_net['kwargs']) + @register_model def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt): return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs']) + @register_model def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt): return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs']) + @register_model def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt): return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs']) diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen3.py b/codes/models/audio/music/unet_diffusion_waveform_gen3.py index 972c26ea..e75d67fb 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen3.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen3.py @@ -1,3 +1,5 @@ +import itertools + import torch import torch.nn as nn import torch.nn.functional as F @@ -303,19 +305,19 @@ class DiffusionWaveformGen(nn.Module): aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning - def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False): + def forward(self, x, timesteps, codes, conditioning_free=False): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. - :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced. :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. :return: an [N x C x ...] Tensor of outputs. """ # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net. orig_x_shape = x.shape[-1] - x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning) + x, codes = self.fix_alignment(x, codes) hs = [] @@ -325,7 +327,7 @@ class DiffusionWaveformGen(nn.Module): if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) else: - code_emb = self.mel_converter(aligned_conditioning) + code_emb = self.mel_converter(codes) time_emb = time_emb.float() h = x @@ -353,6 +355,13 @@ class DiffusionWaveformGen(nn.Module): return out[:, :, :orig_x_shape] + def before_step(self, step): + # The middle block traditionally gets really small gradients; scale them up by an order of magnitude. + scaled_grad_parameters = self.middle_block.parameters() + for p in scaled_grad_parameters: + if hasattr(p, 'grad') and p.grad is not None: + p.grad *= 10 + @register_model def register_unet_diffusion_waveform_gen3(opt_net, opt):