forked from mrq/DL-Art-School
some tfd12 fixes to support multivae
This commit is contained in:
parent
fae05229ec
commit
3f10ce275b
|
@ -425,7 +425,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
|
|||
mel_partition = truth_mel[:, i*partition_size:(i+1)*partition_size]
|
||||
_, p = q.infer(mel_partition)
|
||||
proj.append(p.permute(0,2,1))
|
||||
proj = torch.cat(proj, dim=1)
|
||||
proj = torch.cat(proj, dim=-1)
|
||||
|
||||
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
|
||||
return diff
|
||||
|
@ -551,8 +551,8 @@ def test_vqvae_model():
|
|||
|
||||
|
||||
def test_multi_vqvae_model():
|
||||
clip = torch.randn(2, 100, 400)
|
||||
cond = torch.randn(2,80,400)
|
||||
clip = torch.randn(2, 256, 400)
|
||||
cond = torch.randn(2,256,400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
|
||||
# For music:
|
||||
|
@ -563,16 +563,16 @@ def test_multi_vqvae_model():
|
|||
dropout=.1, vqargs= {
|
||||
'positional_dims': 1, 'channels': 64,
|
||||
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
|
||||
'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
}, num_vaes=4,
|
||||
)
|
||||
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\',
|
||||
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
|
||||
for i, qfile in enumerate(quants):
|
||||
quant_weights = torch.load(qfile)
|
||||
model.quantizer[i].load_state_dict(quant_weights, strict=True)
|
||||
model.quantizers[i].load_state_dict(quant_weights, strict=True)
|
||||
torch.save(model.state_dict(), 'sample.pth')
|
||||
|
||||
print_network(model)
|
||||
|
@ -604,4 +604,4 @@ def test_ar_model():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_vqvae_model()
|
||||
test_multi_vqvae_model()
|
||||
|
|
Loading…
Reference in New Issue
Block a user