forked from mrq/DL-Art-School
Pretrained vqvae option for tfd12..
This commit is contained in:
parent
1fde3e5a08
commit
47330d603b
|
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
from models.arch_util import ResBlock
|
from models.arch_util import ResBlock
|
||||||
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
||||||
|
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
||||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||||
from models.diffusion.unet_diffusion import TimestepBlock
|
from models.diffusion.unet_diffusion import TimestepBlock
|
||||||
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
|
from models.lucidrains.x_transformers import Encoder, Attention, RMSScaleShiftNorm, RotaryEmbedding, \
|
||||||
|
@ -304,7 +305,6 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
||||||
p.grad *= .2
|
p.grad *= .2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDiffusionWithARPrior(nn.Module):
|
class TransformerDiffusionWithARPrior(nn.Module):
|
||||||
def __init__(self, freeze_diff=False, **kwargs):
|
def __init__(self, freeze_diff=False, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -346,6 +346,66 @@ class TransformerDiffusionWithARPrior(nn.Module):
|
||||||
return diff
|
return diff
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDiffusionWithPretrainedVqvae(nn.Module):
    """
    Pairs a TransformerDiffusion model with a frozen DiscreteVAE ("quantizer").

    On every forward pass the quantizer's `infer()` output for `truth_mel` is fed to the diffusion model as
    its `codes` input. The quantizer's parameters are flagged as non-trainable and the module is kept in
    eval mode. This class does not load any weights itself; the caller is expected to load a checkpoint into
    `self.quantizer`.
    """
    def __init__(self, vqargs, **kwargs):
        """
        :param vqargs: dict of keyword arguments used to construct the DiscreteVAE.
        :param kwargs: keyword arguments forwarded unchanged to TransformerDiffusion.
        """
        super().__init__()

        self.internal_step = 0
        self.diff = TransformerDiffusion(**kwargs)
        self.quantizer = DiscreteVAE(**vqargs)
        self.quantizer = self.quantizer.eval()
        for p in self.quantizer.parameters():
            # Both flags are set: requires_grad disables autograd for the parameter, DO_NOT_TRAIN is a
            # project-level marker (presumably read by the trainer when collecting parameters - confirm).
            p.DO_NOT_TRAIN = True
            p.requires_grad = False

    def train(self, mode=True):
        """
        Sets the training mode of this module while keeping the frozen quantizer in eval mode.

        nn.Module.train() recurses into every child, which would otherwise undo the eval() call made in
        __init__ the first time the trainer calls model.train().

        :param mode: True for training mode, False for evaluation mode.
        :return: self
        """
        super().train(mode)
        self.quantizer.eval()
        return self

    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
        """
        :param x: input passed straight through to the diffusion model.
        :param timesteps: diffusion timesteps passed straight through to the diffusion model.
        :param truth_mel: input handed to the quantizer's infer() to produce the conditioning codes.
        :param conditioning_input: forwarded to the diffusion model.
        :param disable_diversity: accepted but not used by this class.
        :param conditioning_free: forwarded to the diffusion model.
        :return: whatever the wrapped diffusion model returns.
        """
        # The quantizer is frozen, so no graph needs to be built for it.
        with torch.no_grad():
            _, proj = self.quantizer.infer(truth_mel)
            # Swap the last two axes before handing the projection to the diffusion model.
            proj = proj.permute(0, 2, 1)

        return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
                         conditioning_free=conditioning_free)

    def get_debug_values(self, step, __):
        """
        Returns the codes recorded by the quantizer so far under the key 'histogram_quant_codes', or an empty
        dict if none have been recorded. Both arguments are unused.
        """
        if self.quantizer.total_codes > 0:
            return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes]}
        else:
            return {}

    def get_grad_norm_parameter_groups(self):
        """
        Returns a dict mapping a group name to a list of diffusion-model parameters. Only parameters of
        self.diff are included; the frozen quantizer is left out.
        """
        attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))
        attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers]))
        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers]))
        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
        blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
        groups = {
            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
            'blk1_attention_layers': attn1,
            'blk2_attention_layers': attn2,
            'attention_layers': attn1 + attn2,
            'blk1_ff_layers': ff1,
            'blk2_ff_layers': ff2,
            'ff_layers': ff1 + ff2,
            'block_out_layers': blkout_layers,
            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
            'out': list(self.diff.out.parameters()),
            'x_proj': list(self.diff.inp_block.parameters()),
            'layers': list(self.diff.layers.parameters()),
            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
            'time_embed': list(self.diff.time_embed.parameters()),
        }
        return groups

    def before_step(self, step):
        """
        Scales the gradients of every diffusion layer's `out` and `prenorm` parameters by 0.2, in place.
        Parameters that currently have no gradient are skipped. `step` is unused.
        """
        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
        # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
        # directly fiddling with the gradients.
        for p in scaled_grad_parameters:
            # A parameter that did not take part in the last backward pass has grad=None; `None *= .2` raises.
            if p.grad is not None:
                p.grad *= .2
|
||||||
|
|
||||||
|
|
||||||
@register_model
def register_transformer_diffusion12(opt_net, opt):
    """
    Builds a TransformerDiffusion from the 'kwargs' entry of the network options.

    :param opt_net: network options; must contain a 'kwargs' mapping of constructor arguments.
    :param opt: unused here.
    """
    model_kwargs = opt_net['kwargs']
    return TransformerDiffusion(**model_kwargs)
|
||||||
|
@ -360,6 +420,10 @@ def register_transformer_diffusion12_with_quantizer(opt_net, opt):
|
||||||
def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
|
def register_transformer_diffusion12_with_ar_prior(opt_net, opt):
|
||||||
return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
|
return TransformerDiffusionWithARPrior(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
@register_model
def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
    """
    Builds a TransformerDiffusionWithPretrainedVqvae from the 'kwargs' entry of the network options.

    :param opt_net: network options; must contain a 'kwargs' mapping of constructor arguments
                    (including 'vqargs' for the quantizer).
    :param opt: unused here.
    """
    # NOTE(review): sibling registrations are spelled 'diffusion12', this one 'diffusion_12'. The name is kept
    # as-is because it may be what existing configs refer to - confirm before renaming.
    model_kwargs = opt_net['kwargs']
    return TransformerDiffusionWithPretrainedVqvae(**model_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def test_quant_model():
|
def test_quant_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
|
@ -390,6 +454,31 @@ def test_quant_model():
|
||||||
print(t)
|
print(t)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vqvae_model(quant_weights_path='D:\\dlas\\experiments\\retrained_dvae_8192_clips.pth'):
    """
    Smoke test: builds a TransformerDiffusionWithPretrainedVqvae, loads quantizer weights from disk, prints the
    network and runs one forward pass plus the grad-norm grouping on random inputs.

    :param quant_weights_path: path to the DiscreteVAE state dict loaded into the quantizer. Defaults to the
                               location used on the original author's machine.
    """
    clip = torch.randn(2, 100, 400)
    cond = torch.randn(2, 80, 400)
    ts = torch.LongTensor([600, 600])

    # For music:
    model = TransformerDiffusionWithPretrainedVqvae(in_channels=100, out_channels=200,
                                                    model_channels=1024, contraction_dim=512,
                                                    prenet_channels=1024, num_heads=8,
                                                    input_vec_dim=512, num_layers=12, prenet_layers=6,
                                                    dropout=.1, vqargs={
                                                        'positional_dims': 1, 'channels': 80,
                                                        'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
                                                        'num_layers': 2, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
                                                    }
                                                    )
    # The model and inputs above live on the CPU, so map the checkpoint there; otherwise a checkpoint saved
    # from a GPU cannot be deserialized on a machine without CUDA.
    quant_weights = torch.load(quant_weights_path, map_location='cpu')
    model.quantizer.load_state_dict(quant_weights, strict=True)
    #torch.save(model.state_dict(), 'sample.pth')

    print_network(model)
    # Results are discarded on purpose; the calls only need to complete without raising.
    model(clip, ts, cond)
    model.get_grad_norm_parameter_groups()
|
||||||
|
|
||||||
|
|
||||||
def test_ar_model():
|
def test_ar_model():
|
||||||
clip = torch.randn(2, 256, 400)
|
clip = torch.randn(2, 256, 400)
|
||||||
cond = torch.randn(2, 256, 400)
|
cond = torch.randn(2, 256, 400)
|
||||||
|
@ -414,4 +503,4 @@ def test_ar_model():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_quant_model()
|
test_vqvae_model()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user