From 47330d603b459f131b5204f7f47a597d47174c7e Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 13 Jun 2022 11:19:33 -0600
Subject: [PATCH] Pretrained vqvae option for tfd12..

---
 .../audio/music/transformer_diffusion12.py   | 93 ++++++++++++++++++-
 1 file changed, 91 insertions(+), 2 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 43d29421..0359767b 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 
 from models.arch_util import ResBlock
 from models.audio.music.music_quantizer2 import MusicQuantizer2
+from models.audio.tts.lucidrains_dvae import DiscreteVAE
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
 from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
@@ -304,7 +305,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                 p.grad *= .2
 
 
-
 class TransformerDiffusionWithARPrior(nn.Module):
     def __init__(self, freeze_diff=False, **kwargs):
         super().__init__()
@@ -346,6 +346,66 @@ class TransformerDiffusionWithARPrior(nn.Module):
         return diff
 
 
+class TransformerDiffusionWithPretrainedVqvae(nn.Module):
+    def __init__(self, vqargs, **kwargs):
+        super().__init__()
+
+        self.internal_step = 0
+        self.diff = TransformerDiffusion(**kwargs)
+        self.quantizer = DiscreteVAE(**vqargs)
+        self.quantizer = self.quantizer.eval()
+        for p in self.quantizer.parameters():
+            p.DO_NOT_TRAIN = True
+            p.requires_grad = False
+
+    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
+        with torch.no_grad():
+            reconstructed, proj = self.quantizer.infer(truth_mel)
+            proj = proj.permute(0, 2, 1)
+
+        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        return diff
+
+    def get_debug_values(self, step, __):
+        if self.quantizer.total_codes > 0:
+            return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes]}
+        else:
+            return {}
+
+    def get_grad_norm_parameter_groups(self):
+        attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))
+        attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers]))
+        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers]))
+        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
+        blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
+        groups = {
+            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
+            'blk1_attention_layers': attn1,
+            'blk2_attention_layers': attn2,
+            'attention_layers': attn1 + attn2,
+            'blk1_ff_layers': ff1,
+            'blk2_ff_layers': ff2,
+            'ff_layers': ff1 + ff2,
+            'block_out_layers': blkout_layers,
+            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
+            'out': list(self.diff.out.parameters()),
+            'x_proj': list(self.diff.inp_block.parameters()),
+            'layers': list(self.diff.layers.parameters()),
+            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
+            'time_embed': list(self.diff.time_embed.parameters()),
+        }
+        return groups
+
+    def before_step(self, step):
+        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
+                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
+        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of
+        # magnitude higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this
+        # trickier than directly fiddling with the gradients.
+        for p in scaled_grad_parameters:
+            p.grad *= .2
+
+
 @register_model
 def register_transformer_diffusion12(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])
@@ -360,6 +420,10 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
 def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
     return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
 
+@register_model
+def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
+    return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
+
 
 def test_quant_model():
     clip = torch.randn(2, 256, 400)
@@ -390,6 +454,31 @@ def test_quant_model():
     print(t)
 
 
+def test_vqvae_model():
+    clip = torch.randn(2, 100, 400)
+    cond = torch.randn(2, 80, 400)
+    ts = torch.LongTensor([600, 600])
+
+    # For music:
+    model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
+                                                    model_channels=1024, contraction_dim=512,
+                                                    prenet_channels=1024, num_heads=8,
+                                                    input_vec_dim=512, num_layers=12, prenet_layers=6,
+                                                    dropout=.1, vqargs={
+            'positional_dims': 1, 'channels': 80,
+            'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
+            'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
+        }
+    )
+    quant_weights = torch.load('D:\\dlas\\experiments\\retrained_dvae_8192_clips.pth')
+    model.quantizer.load_state_dict(quant_weights, strict=True)
+    #torch.save(model.state_dict(), 'sample.pth')
+
+    print_network(model)
+    o = model(clip, ts, cond)
+    pg = model.get_grad_norm_parameter_groups()
+
+
 def test_ar_model():
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
@@ -414,4 +503,4 @@ def test_ar_model():
 
 
 if __name__ == '__main__':
-    test_quant_model()
+    test_vqvae_model()
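
A minimal usage sketch follows, showing how the new class would be reached through the registry entry added above (register_transformer_diffusion_12_with_pretrained_vqvae simply unpacks opt_net['kwargs'] into the constructor). The shape of the opt_net dict, the import path, and the checkpoint filename below are illustrative assumptions based on the kwargs exercised in test_vqvae_model, not something specified by the patch itself.

# Usage sketch (assumptions: kwargs copied from test_vqvae_model above; the
# checkpoint path is a placeholder for a pretrained DiscreteVAE state dict).
import torch

from models.audio.music.transformer_diffusion12 import TransformerDiffusionWithPretrainedVqvae

opt_net = {
    'kwargs': {
        'in_channels': 100, 'out_channels': 200,
        'model_channels': 1024, 'contraction_dim': 512,
        'prenet_channels': 1024, 'num_heads': 8,
        'input_vec_dim': 512, 'num_layers': 12, 'prenet_layers': 6,
        'dropout': .1,
        'vqargs': {
            'positional_dims': 1, 'channels': 80,
            'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
            'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
        },
    }
}

# Mirrors what register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt) does:
# 'vqargs' configures the frozen DiscreteVAE, everything else goes to TransformerDiffusion.
model = TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])

# The dVAE is frozen at construction time; only its pretrained weights need to be loaded.
dvae_weights = torch.load('path/to/pretrained_dvae.pth')  # placeholder path
model.quantizer.load_state_dict(dvae_weights, strict=True)

# Forward pass: x is the noised diffusion target, truth_mel is fed to the frozen dVAE
# to produce the code projection that conditions the diffusion transformer.
x = torch.randn(2, 100, 400)
truth_mel = torch.randn(2, 80, 400)
timesteps = torch.LongTensor([600, 600])
out = model(x, timesteps, truth_mel)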