fix codes

James Betker 2022-06-19 21:29:40 -06:00
parent b9f53a3ff9
commit 691ed196da
2 changed files with 16 additions and 4 deletions

View File

@@ -596,14 +596,17 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
 def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
     return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
 
 
 @register_model
 def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
     return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
 
 
 @register_model
 def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
     return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])
 
 
+@register_model
+def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
+    return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
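All four shims above follow the same one-liner pattern: @register_model files the factory function away under its own name, and the trainer later looks that name up from the config and splats opt_net['kwargs'] into the model constructor. A minimal sketch of that plumbing, assuming a simple name-keyed registry (the _MODEL_REGISTRY dict, the 'which_model' key, and the create_model helper are illustrative, not the repo's actual trainer code):

_MODEL_REGISTRY = {}

def register_model(fn):
    # File the factory under its function name, e.g.
    # 'register_transformer_diffusion_12_with_cheater_latent'.
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn

def create_model(opt_net, opt):
    # The config names the factory to use; the factory unpacks
    # opt_net['kwargs'] into the model constructor, as in the diff above.
    return _MODEL_REGISTRY[opt_net['which_model']](opt_net, opt)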

View File

@@ -1,3 +1,5 @@
+import itertools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -303,19 +305,19 @@ class DiffusionWaveformGen(nn.Module):
             aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
         return x, aligned_conditioning
 
-    def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
+    def forward(self, x, timesteps, codes, conditioning_free=False):
         """
         Apply the model to an input batch.
 
         :param x: an [N x C x ...] Tensor of inputs.
         :param timesteps: a 1-D batch of timesteps.
-        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
+        :param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
         :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
         :return: an [N x C x ...] Tensor of outputs.
         """
         # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
         orig_x_shape = x.shape[-1]
-        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
+        x, codes = self.fix_alignment(x, codes)
 
         hs = []
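The comment in forward explains why fix_alignment exists: each U-net stage halves the sequence length, so the input (and its codes) must be padded to a compatible multiple on the way in, and the output trimmed back to orig_x_shape on the way out. A self-contained sketch of that round trip, assuming four downsampling stages (the real multiple depends on the model's block configuration):

import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple):
    # Right-pad the time dimension so repeated halving inside the
    # U-net divides evenly; mirrors the F.pad call in fix_alignment.
    pc = -x.shape[-1] % multiple
    return F.pad(x, (0, pc))

x = torch.randn(2, 256, 1000)    # [N x C x T] input; T is not a clean multiple
orig_x_shape = x.shape[-1]
x = pad_to_multiple(x, 2 ** 4)   # assumed 4 stages -> pad 1000 up to 1008
out = x                          # stand-in for the actual U-net pass
out = out[:, :, :orig_x_shape]   # trim back, exactly as the diff's return does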
@@ -325,7 +327,7 @@ class DiffusionWaveformGen(nn.Module):
         if conditioning_free:
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
         else:
-            code_emb = self.mel_converter(aligned_conditioning)
+            code_emb = self.mel_converter(codes)
 
         time_emb = time_emb.float()
         h = x
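The conditioning_free switch is the conditioning-dropout half of classifier-free guidance: instead of converting codes into an embedding, a single learned "unconditioned" vector is broadcast across the batch, and a sampler can then blend conditioned and unconditioned outputs. A sketch of just that branch logic, with the parameter shape inferred from the repeat(x.shape[0], 1, 1) call (the Conv1d stands in for the real mel_converter):

import torch
import torch.nn as nn

class CodeEmbedSketch(nn.Module):
    def __init__(self, dim=512, seq_len=100):
        super().__init__()
        # One learned [1 x dim x seq_len] embedding meaning "no conditioning".
        self.unconditioned_embedding = nn.Parameter(torch.randn(1, dim, seq_len))
        self.mel_converter = nn.Conv1d(dim, dim, kernel_size=1)  # stand-in

    def embed(self, x, codes, conditioning_free=False):
        if conditioning_free:
            # Broadcast the single learned vector over the batch, as in the diff.
            return self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
        return self.mel_converter(codes)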
@@ -353,6 +355,13 @@ class DiffusionWaveformGen(nn.Module):
         return out[:, :, :orig_x_shape]
 
+    def before_step(self, step):
+        # The middle block traditionally gets really small gradients; scale them up by an order of magnitude.
+        scaled_grad_parameters = self.middle_block.parameters()
+        for p in scaled_grad_parameters:
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= 10
+
 @register_model
 def register_unet_diffusion_waveform_gen3(opt_net, opt):
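before_step is a trainer hook rather than part of the forward pass: it runs once gradients exist (after backward()) and before the optimizer applies them, multiplying the middle block's gradients by 10 to offset the vanishing gradients the comment describes. A runnable toy showing where the hook fits (TinyGen and the SGD loop are illustrative only; the hook body follows the same pattern as the diff):

import torch
import torch.nn as nn

class TinyGen(nn.Module):
    def __init__(self):
        super().__init__()
        self.middle_block = nn.Linear(8, 8)

    def forward(self, x):
        return self.middle_block(x)

    def before_step(self, step):
        # Boost the middle block's gradients 10x, after backward()
        # but before optimizer.step(), as in the diff.
        for p in self.middle_block.parameters():
            if hasattr(p, 'grad') and p.grad is not None:
                p.grad *= 10

model = TinyGen()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(4, 8)).square().mean()
loss.backward()            # .grad is populated here
model.before_step(step=0)  # gradients exist now; scale them
optimizer.step()

Scaling gradients in a hook like this is a blunt but dependency-free alternative to giving the middle block its own optimizer parameter group with a larger learning rate.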