# Reconstructed from a unified diff of
# codes/models/audio/music/transformer_diffusion12.py (index f63f53f1..bd5cd0b5).
# Hunk boundaries are kept as comments. Unchanged context that a hunk shows only
# in part is reproduced exactly as far as it is visible.

# --- hunk @@ -406,6 +406,73 @@ (context: class TransformerDiffusionWithPretrainedVqvae) ---
# Tail of the preceding method, shown as context only:
#             p.grad *= .2


class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
    """Wrap a TransformerDiffusion together with several frozen DiscreteVAE quantizers.

    In forward(), the conditioning tensor is cut into equally sized slices along
    dim 1, one slice per quantizer. Each slice is passed through that quantizer's
    infer() method and the results are concatenated into the `codes` argument of
    the wrapped diffusion model.
    """

    def __init__(self, num_vaes=4, vqargs=None, **kwargs):
        """
        :param num_vaes: Number of DiscreteVAE quantizers to build (all with the same arguments).
        :param vqargs: Dict of keyword arguments used to construct every DiscreteVAE. Required.
        :param kwargs: Forwarded unchanged to TransformerDiffusion.
        :raises ValueError: If vqargs is not supplied or num_vaes is less than 1.
        """
        super().__init__()
        # `vqargs` only carries a default so that it may legally follow `num_vaes=4`
        # in the signature; it is still mandatory.
        if vqargs is None:
            raise ValueError('vqargs is required to construct the DiscreteVAE quantizers')
        if num_vaes < 1:
            raise ValueError(f'num_vaes must be at least 1, got {num_vaes}')

        self.internal_step = 0
        self.diff = TransformerDiffusion(**kwargs)
        self.quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
        # The quantizers are frozen: flag and detach every parameter from training.
        for param in self.quantizers.parameters():
            param.DO_NOT_TRAIN = True
            param.requires_grad = False

    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
        """Quantize `truth_mel` slice by slice and run the wrapped diffusion model on the result.

        :param x: Passed through to the wrapped diffusion model.
        :param timesteps: Passed through to the wrapped diffusion model.
        :param truth_mel: Tensor that is split along dim 1 into len(self.quantizers) slices.
        :param conditioning_input: Passed through to the wrapped diffusion model.
        :param disable_diversity: Accepted but not used by this method.
        :param conditioning_free: Passed through to the wrapped diffusion model.
        :return: Whatever the wrapped diffusion model returns.
        """
        with torch.no_grad():
            # NOTE(review): when truth_mel.shape[1] is not a multiple of the quantizer count,
            # the floor division below leaves the trailing entries along dim 1 unused.
            partition_size = truth_mel.shape[1] // len(self.quantizers)
            projections = []
            for i, quantizer in enumerate(self.quantizers):
                mel_partition = truth_mel[:, i * partition_size:(i + 1) * partition_size]
                # Only the second value returned by infer() is used.
                _, projected = quantizer.infer(mel_partition)
                projections.append(projected.permute(0, 2, 1))
            proj = torch.cat(projections, dim=1)

        return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
                         conditioning_free=conditioning_free)

    def get_debug_values(self, step, __):
        """Return the recorded codes of every quantizer, keyed 'histogram_quant{i}_codes'.

        An empty dict is returned while the first quantizer has not recorded any codes;
        only the first quantizer's counter is consulted for that decision.
        """
        if self.quantizers[0].total_codes <= 0:
            return {}
        return {f'histogram_quant{i}_codes': quantizer.codes[:quantizer.total_codes]
                for i, quantizer in enumerate(self.quantizers)}

    def get_grad_norm_parameter_groups(self):
        """Return a dict mapping a group name to the list of parameters in that group."""

        def per_layer(select):
            # Flatten the parameters of one sub-module taken from every diffusion layer.
            return list(itertools.chain.from_iterable(select(lyr).parameters() for lyr in self.diff.layers))

        attn1 = per_layer(lambda lyr: lyr.block1.attn)
        attn2 = per_layer(lambda lyr: lyr.block2.attn)
        ff1 = per_layer(lambda lyr: lyr.block1.ff)
        ff2 = per_layer(lambda lyr: lyr.block2.ff)
        blkout_layers = per_layer(lambda lyr: lyr.out)
        groups = {
            'prenorms': per_layer(lambda lyr: lyr.prenorm),
            'blk1_attention_layers': attn1,
            'blk2_attention_layers': attn2,
            'attention_layers': attn1 + attn2,
            'blk1_ff_layers': ff1,
            'blk2_ff_layers': ff2,
            'ff_layers': ff1 + ff2,
            'block_out_layers': blkout_layers,
            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
            'out': list(self.diff.out.parameters()),
            'x_proj': list(self.diff.inp_block.parameters()),
            'layers': list(self.diff.layers.parameters()),
            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
            'time_embed': list(self.diff.time_embed.parameters()),
        }
        return groups

    def before_step(self, step):
        """Multiply the gradients of every layer's `out` and `prenorm` parameters by 0.2.

        :param step: Not used by this method.
        """
        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of
        # magnitude higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this
        # trickier than directly fiddling with the gradients.
        scaled_grad_parameters = itertools.chain(
            itertools.chain.from_iterable(lyr.out.parameters() for lyr in self.diff.layers),
            itertools.chain.from_iterable(lyr.prenorm.parameters() for lyr in self.diff.layers),
        )
        for param in scaled_grad_parameters:
            # A parameter that received no gradient has grad == None; there is nothing to scale.
            if param.grad is not None:
                param.grad *= .2


@register_model
def register_transformer_diffusion12(opt_net, opt):
    return TransformerDiffusion(**opt_net['kwargs'])


# --- hunk @@ -424,6 +491,10 @@ (context: def register_transformer_diffusion12_with_ar_prior(opt_net, opt)) ---
# NOTE(review): any decorator on the next function lies above the hunk and is not visible here.
def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
    return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])


@register_model
def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
    """Build a TransformerDiffusionWithMultiPretrainedVqvae from opt_net['kwargs']."""
    return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])


def test_quant_model():
    clip = torch.randn(2, 256, 400)
    # ... remainder of test_quant_model is outside the hunk ...


# --- hunk @@ -479,6 +550,36 @@ (context: def test_vqvae_model()) ---
# Tail of test_vqvae_model, shown as context only:
#     pg = model.get_grad_norm_parameter_groups()


def test_multi_vqvae_model():
    """Manual smoke test: build the multi-VQVAE model, load quantizer weights, save it and run one forward pass."""
    # NOTE(review): `clip` has 100 channels while in_channels=256, and `cond` has 80 channels,
    # i.e. 20 per quantizer, while the quantizers are built with channels=64. Confirm these shapes.
    clip = torch.randn(2, 100, 400)
    cond = torch.randn(2,80,400)
    ts = torch.LongTensor([600, 600])

    # For music:
    model = TransformerDiffusionWithMultiPretrainedVqvae(in_channels=256, out_channels=200,
                                                         model_channels=1024, contraction_dim=512,
                                                         prenet_channels=1024, num_heads=8,
                                                         input_vec_dim=2048, num_layers=12, prenet_layers=6,
                                                         dropout=.1, vqargs= {
                                                             'positional_dims': 1, 'channels': 64,
            'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
            'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
                                                         }, num_vaes=4,
                                                         )
    # NOTE(review): the first two entries end in a directory separator and name no checkpoint
    # file, so torch.load() cannot succeed on them as written. Fill in the file names.
    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\',
              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\',
              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
    for i, qfile in enumerate(quants):
        quant_weights = torch.load(qfile)
        # The ModuleList attribute is `quantizers`, not `quantizer`.
        model.quantizers[i].load_state_dict(quant_weights, strict=True)
    torch.save(model.state_dict(), 'sample.pth')

    print_network(model)
    o = model(clip, ts, cond)
    pg = model.get_grad_norm_parameter_groups()


def test_ar_model():
    clip = torch.randn(2, 256, 400)
    cond = torch.randn(2, 256, 400)
    # ... remainder of test_ar_model is outside the hunk ...