fix codes
This commit is contained in:
parent
b9f53a3ff9
commit
691ed196da
|
@ -596,14 +596,17 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
|
||||||
def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
|
def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
|
||||||
return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
|
return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
|
def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
|
||||||
return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
|
return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
|
def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
|
||||||
return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])
|
return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
|
def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
|
||||||
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
|
return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import itertools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -303,19 +305,19 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
|
aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
|
||||||
return x, aligned_conditioning
|
return x, aligned_conditioning
|
||||||
|
|
||||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False):
|
def forward(self, x, timesteps, codes, conditioning_free=False):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
|
|
||||||
:param x: an [N x C x ...] Tensor of inputs.
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
:param timesteps: a 1-D batch of timesteps.
|
:param timesteps: a 1-D batch of timesteps.
|
||||||
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
:param codes: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
||||||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
|
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
|
||||||
orig_x_shape = x.shape[-1]
|
orig_x_shape = x.shape[-1]
|
||||||
x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
|
x, codes = self.fix_alignment(x, codes)
|
||||||
|
|
||||||
|
|
||||||
hs = []
|
hs = []
|
||||||
|
@ -325,7 +327,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||||
else:
|
else:
|
||||||
code_emb = self.mel_converter(aligned_conditioning)
|
code_emb = self.mel_converter(codes)
|
||||||
|
|
||||||
time_emb = time_emb.float()
|
time_emb = time_emb.float()
|
||||||
h = x
|
h = x
|
||||||
|
@ -353,6 +355,13 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
|
|
||||||
return out[:, :, :orig_x_shape]
|
return out[:, :, :orig_x_shape]
|
||||||
|
|
||||||
|
def before_step(self, step):
|
||||||
|
# The middle block traditionally gets really small gradients; scale them up by an order of magnitude.
|
||||||
|
scaled_grad_parameters = self.middle_block.parameters()
|
||||||
|
for p in scaled_grad_parameters:
|
||||||
|
if hasattr(p, 'grad') and p.grad is not None:
|
||||||
|
p.grad *= 10
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_unet_diffusion_waveform_gen3(opt_net, opt):
|
def register_unet_diffusion_waveform_gen3(opt_net, opt):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user