forked from mrq/DL-Art-School
some tfd12 fixes to support multivae
parent fae05229ec
commit 3f10ce275b
@@ -425,7 +425,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
                 mel_partition = truth_mel[:, i*partition_size:(i+1)*partition_size]
                 _, p = q.infer(mel_partition)
                 proj.append(p.permute(0,2,1))
-            proj = torch.cat(proj, dim=1)
+            proj = torch.cat(proj, dim=-1)

         diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
         return diff
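Note on the dim fix above: each element of proj is one VAE's projection after permute(0,2,1), so the concatenation axis decides whether the per-VAE partitions are stacked along the middle dimension or the final dimension of the codes tensor handed to self.diff. A minimal shape check, assuming an illustrative per-partition shape of (2, 1024, 100) that is not taken from the model:

import torch

num_vaes = 4
proj = [torch.randn(2, 1024, 100) for _ in range(num_vaes)]  # one projection per VAE partition (shape assumed)

cat_mid = torch.cat(proj, dim=1)    # old behaviour: stacks the middle axis -> (2, 4096, 100)
cat_last = torch.cat(proj, dim=-1)  # new behaviour: stacks the final axis  -> (2, 1024, 400)
print(cat_mid.shape, cat_last.shape)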
@@ -551,8 +551,8 @@ def test_vqvae_model():


 def test_multi_vqvae_model():
-    clip = torch.randn(2, 100, 400)
-    cond = torch.randn(2,80,400)
+    clip = torch.randn(2, 256, 400)
+    cond = torch.randn(2,256,400)
     ts = torch.LongTensor([600, 600])

     # For music:
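Shape bookkeeping for the new test tensors, as a hedged sketch: with num_vaes=4 and 'channels': 64 in the vqargs of the next hunk, a 256-channel clip splits cleanly into four 64-channel bands via the truth_mel slicing shown in the first hunk. How partition_size is actually computed inside the model is assumed here, not quoted from it:

import torch

num_vaes, channels_per_vae, frames = 4, 64, 400
truth_mel = torch.randn(2, num_vaes * channels_per_vae, frames)   # mirrors torch.randn(2, 256, 400)

partition_size = truth_mel.shape[1] // num_vaes                   # 64 channels per band (assumed)
parts = [truth_mel[:, i*partition_size:(i+1)*partition_size] for i in range(num_vaes)]
print([tuple(p.shape) for p in parts])                            # four (2, 64, 400) slices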
@@ -563,16 +563,16 @@ def test_multi_vqvae_model():
                         dropout=.1, vqargs= {
                             'positional_dims': 1, 'channels': 64,
                             'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
-                            'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
+                            'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
                         }, num_vaes=4,
                         )
-    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\',
-              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\',
+    quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
+              'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
               'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
               'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
     for i, qfile in enumerate(quants):
         quant_weights = torch.load(qfile)
-        model.quantizer[i].load_state_dict(quant_weights, strict=True)
+        model.quantizers[i].load_state_dict(quant_weights, strict=True)
     torch.save(model.state_dict(), 'sample.pth')

     print_network(model)
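The quantizer -> quantizers rename in this hunk only works if the model registers its per-band VAEs in an indexable container. The sketch below assumes an nn.ModuleList named quantizers and uses throwaway DummyVae modules plus temporary checkpoint files rather than the real generator weights:

import torch
import torch.nn as nn

class DummyVae(nn.Module):
    # stand-in for a per-band VQVAE; only here to make the sketch self-contained
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

quantizers = nn.ModuleList([DummyVae() for _ in range(4)])

# write stand-in checkpoints, then load them back by band index
for i, vae in enumerate(quantizers):
    torch.save(vae.state_dict(), f'band{i}_generator.pth')
for i in range(len(quantizers)):
    state = torch.load(f'band{i}_generator.pth', map_location='cpu')
    quantizers[i].load_state_dict(state, strict=True)   # strict=True surfaces any key mismatch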
@@ -604,4 +604,4 @@ def test_ar_model():


 if __name__ == '__main__':
-    test_vqvae_model()
+    test_multi_vqvae_model()