diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index f82a267d..128f5168 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -425,7 +425,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
             mel_partition = truth_mel[:, i*partition_size:(i+1)*partition_size]
             _, p = q.infer(mel_partition)
             proj.append(p.permute(0,2,1))
-        proj = torch.cat(proj, dim=1)
+        proj = torch.cat(proj, dim=-1)

         diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
         return diff
@@ -551,8 +551,8 @@ def test_vqvae_model():


 def test_multi_vqvae_model():
-    clip = torch.randn(2, 100, 400)
-    cond = torch.randn(2,80,400)
+    clip = torch.randn(2, 256, 400)
+    cond = torch.randn(2,256,400)
     ts = torch.LongTensor([600, 600])

     # For music:
@@ -563,16 +563,16 @@ def test_multi_vqvae_model():
                                                       dropout=.1, vqargs= {
             'positional_dims': 1, 'channels': 64,
             'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
-            'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
+            'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
         }, num_vaes=4,
     )

-    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\',
-              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\',
+    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
               'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
               'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
     for i, qfile in enumerate(quants):
         quant_weights = torch.load(qfile)
-        model.quantizer[i].load_state_dict(quant_weights, strict=True)
+        model.quantizers[i].load_state_dict(quant_weights, strict=True)
     torch.save(model.state_dict(), 'sample.pth')
     print_network(model)
@@ -604,4 +604,4 @@ def test_ar_model():


 if __name__ == '__main__':
-    test_vqvae_model()
+    test_multi_vqvae_model()