some tfd12 fixes to support multivae

This commit is contained in:
James Betker 2022-06-14 23:53:50 -06:00
parent fae05229ec
commit 3f10ce275b

View File

@ -425,7 +425,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
mel_partition = truth_mel[:, i*partition_size:(i+1)*partition_size]
_, p = q.infer(mel_partition)
proj.append(p.permute(0,2,1))
proj = torch.cat(proj, dim=1)
proj = torch.cat(proj, dim=-1)
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
return diff
@ -551,8 +551,8 @@ def test_vqvae_model():
def test_multi_vqvae_model():
clip = torch.randn(2, 100, 400)
cond = torch.randn(2,80,400)
clip = torch.randn(2, 256, 400)
cond = torch.randn(2,256,400)
ts = torch.LongTensor([600, 600])
# For music:
@ -563,16 +563,16 @@ def test_multi_vqvae_model():
dropout=.1, vqargs= {
'positional_dims': 1, 'channels': 64,
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
}, num_vaes=4,
)
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\',
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
for i, qfile in enumerate(quants):
quant_weights = torch.load(qfile)
model.quantizer[i].load_state_dict(quant_weights, strict=True)
model.quantizers[i].load_state_dict(quant_weights, strict=True)
torch.save(model.state_dict(), 'sample.pth')
print_network(model)
@ -604,4 +604,4 @@ def test_ar_model():
if __name__ == '__main__':
test_vqvae_model()
test_multi_vqvae_model()