forked from mrq/DL-Art-School

fix codes

parent b9f53a3ff9
commit 691ed196da
@@ -596,14 +596,17 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
 def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
     return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
 
 
 @register_model
 def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
     return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
 
 
 @register_model
 def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
     return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])
 
 
+@register_model
+def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
+    return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
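For context on how these factory functions are consumed: each `@register_model` entry builds its model from a config dict whose `kwargs` sub-dict is splatted into the constructor. A minimal sketch of exercising one of them, assuming a hypothetical config (only the `kwargs` key is evidenced by the functions above; the constructor arguments inside it are illustrative, not taken from this diff):

    # Hypothetical opt_net dict; the inner keys are made-up constructor
    # arguments, used only to show the **opt_net['kwargs'] splat pattern.
    opt_net = {'kwargs': {'model_channels': 1024, 'num_layers': 16}}
    model = register_transformer_diffusion_12_with_cheater_latent(opt_net, opt=None)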
@@ -1,3 +1,5 @@
+import itertools
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -303,19 +305,19 @@ class DiffusionWaveformGen(nn.Module):
             aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
         return x, aligned_conditioning
 
-    def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
+    def forward(self, x, timesteps, codes, conditioning_free=False):
         """
         Apply the model to an input batch.
 
         :param x: an [N x C x ...] Tensor of inputs.
         :param timesteps: a 1-D batch of timesteps.
-        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
+        :param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
         :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
         :return: an [N x C x ...] Tensor of outputs.
         """
         # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
         orig_x_shape = x.shape[-1]
-        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
+        x, codes = self.fix_alignment(x, codes)
 
 
         hs = []
@@ -325,7 +327,7 @@ class DiffusionWaveformGen(nn.Module):
         if conditioning_free:
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
         else:
-            code_emb = self.mel_converter(aligned_conditioning)
+            code_emb = self.mel_converter(codes)
 
         time_emb = time_emb.float()
         h = x
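Since the change is a pure rename of `aligned_conditioning` to `codes`, only call sites that passed the argument by keyword are affected. A minimal sketch of calling the updated signature, assuming `model` is a constructed DiffusionWaveformGen and with illustrative tensor shapes (the real channel counts and lengths depend on the model config, not this diff):

    import torch

    x = torch.randn(2, 256, 2048)            # [N x C x T] noisy waveform input
    timesteps = torch.randint(0, 4000, (2,)) # 1-D batch of diffusion timesteps
    codes = torch.randn(2, 256, 2048)        # aligned latent conditioning

    y = model(x, timesteps, codes)                               # conditioned pass
    y_free = model(x, timesteps, codes, conditioning_free=True)  # unconditioned pass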
@@ -353,6 +355,13 @@ class DiffusionWaveformGen(nn.Module):
 
         return out[:, :, :orig_x_shape]
 
+    def before_step(self, step):
+        # The middle block traditionally gets really small gradients; scale them up by an order of magnitude.
+        scaled_grad_parameters = self.middle_block.parameters()
+        for p in scaled_grad_parameters:
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= 10
+
 
 @register_model
 def register_unet_diffusion_waveform_gen3(opt_net, opt):
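The new before_step hook only works if the surrounding trainer calls it after backward() and before the optimizer update; that ordering is what makes the in-place gradient scaling take effect. A minimal sketch of that placement (the loop and its helpers are assumptions for illustration; the actual trainer wiring is not part of this diff):

    for step, batch in enumerate(loader):   # hypothetical data loader
        optimizer.zero_grad()
        loss = compute_loss(model, batch)   # hypothetical loss helper
        loss.backward()
        model.before_step(step)             # boost middle-block gradients 10x, in place
        optimizer.step()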