forked from mrq/DL-Art-School
multivqvae tfd12
This commit is contained in:
parent
d29ea0df5e
commit
6bc19d1328
|
@ -406,6 +406,73 @@ class TransformerDiffusionWithPretrainedVqvae(nn.Module):
|
|||
p.grad *= .2
|
||||
|
||||
|
||||
class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
|
||||
def __init__(self, num_vaes=4, vqargs, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.internal_step = 0
|
||||
self.diff = TransformerDiffusion(**kwargs)
|
||||
self.quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
|
||||
for p in self.quantizers.parameters():
|
||||
p.DO_NOT_TRAIN = True
|
||||
p.requires_grad = False
|
||||
|
||||
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
||||
with torch.no_grad():
|
||||
proj = []
|
||||
partition_size = truth_mel.shape[1] // len(self.quantizers)
|
||||
for i, q in enumerate(self.quantizers):
|
||||
mel_partition = truth_mel[:, i*partition_size:(i+1)*partition_size]
|
||||
_, p = q.infer(mel_partition)
|
||||
proj.append(p.permute(0,2,1))
|
||||
proj = torch.cat(proj, dim=1)
|
||||
|
||||
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
|
||||
return diff
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.quantizers[0].total_codes > 0:
|
||||
dbgs = {}
|
||||
for i in range(len(self.quantizers)):
|
||||
dbgs[f'histogram_quant{i}_codes'] = self.quantizers[i].codes[:self.quantizers[i].total_codes]
|
||||
return dbgs
|
||||
else:
|
||||
return {}
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))
|
||||
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers]))
|
||||
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers]))
|
||||
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
|
||||
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
|
||||
groups = {
|
||||
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
|
||||
'blk1_attention_layers': attn1,
|
||||
'blk2_attention_layers': attn2,
|
||||
'attention_layers': attn1 + attn2,
|
||||
'blk1_ff_layers': ff1,
|
||||
'blk2_ff_layers': ff2,
|
||||
'ff_layers': ff1 + ff2,
|
||||
'block_out_layers': blkout_layers,
|
||||
'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
|
||||
'out': list(self.diff.out.parameters()),
|
||||
'x_proj': list(self.diff.inp_block.parameters()),
|
||||
'layers': list(self.diff.layers.parameters()),
|
||||
'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
|
||||
'time_embed': list(self.diff.time_embed.parameters()),
|
||||
}
|
||||
return groups
|
||||
|
||||
def before_step(self, step):
|
||||
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
|
||||
list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
|
||||
# Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
|
||||
# higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
|
||||
# directly fiddling with the gradients.
|
||||
for p in scaled_grad_parameters:
|
||||
p.grad *= .2
|
||||
|
||||
|
||||
@register_model
|
||||
def register_transformer_diffusion12(opt_net, opt):
|
||||
return TransformerDiffusion(**opt_net['kwargs'])
|
||||
|
@ -424,6 +491,10 @@ def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
|
|||
def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
|
||||
return TransformerDiffusionWithPretrainedVqvae(**opt_net['kwargs'])
|
||||
|
||||
@register_model
|
||||
def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
|
||||
return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])
|
||||
|
||||
|
||||
def test_quant_model():
|
||||
clip = torch.randn(2, 256, 400)
|
||||
|
@ -479,6 +550,36 @@ def test_vqvae_model():
|
|||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
def test_multi_vqvae_model():
|
||||
clip = torch.randn(2, 100, 400)
|
||||
cond = torch.randn(2,80,400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
|
||||
# For music:
|
||||
model = TransformerDiffusionWithMultiPretrainedVqvae(in_channels=256, out_channels=200,
|
||||
model_channels=1024, contraction_dim=512,
|
||||
prenet_channels=1024, num_heads=8,
|
||||
input_vec_dim=2048, num_layers=12, prenet_layers=6,
|
||||
dropout=.1, vqargs= {
|
||||
'positional_dims': 1, 'channels': 64,
|
||||
'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
|
||||
'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
|
||||
}, num_vaes=4,
|
||||
)
|
||||
quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
|
||||
'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth']
|
||||
for i, qfile in enumerate(quants):
|
||||
quant_weights = torch.load(qfile)
|
||||
model.quantizer[i].load_state_dict(quant_weights, strict=True)
|
||||
torch.save(model.state_dict(), 'sample.pth')
|
||||
|
||||
print_network(model)
|
||||
o = model(clip, ts, cond)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
def test_ar_model():
|
||||
clip = torch.randn(2, 256, 400)
|
||||
cond = torch.randn(2, 256, 400)
|
||||
|
|
Loading…
Reference in New Issue
Block a user