diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py
index c8c38e3f..1e2afdd5 100644
--- a/codes/models/audio/music/gpt_music2.py
+++ b/codes/models/audio/music/gpt_music2.py
@@ -131,7 +131,6 @@ class GptMusicLower(nn.Module):
         return groups
 
 
-
 @register_model
 def register_music_gpt_lower2(opt_net, opt):
     return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index ec39c921..3faf3f8b 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -6,6 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models.arch_util import ResBlock
+from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
 from models.audio.music.music_quantizer2 import MusicQuantizer2
 from models.audio.tts.lucidrains_dvae import DiscreteVAE
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@@ -168,11 +169,16 @@ class TransformerDiffusion(nn.Module):
             for p in self.parameters():
                 p.DO_NOT_TRAIN = True
                 p.requires_grad = False
-            for m in [self.ar_input and self.ar_prior_intg]:
-                for p in m.parameters():
-                    del p.DO_NOT_TRAIN
-                    p.requires_grad = True
-
+            if hasattr(self, 'ar_input'):
+                for m in [self.ar_input and self.ar_prior_intg]:
+                    for p in m.parameters():
+                        del p.DO_NOT_TRAIN
+                        p.requires_grad = True
+            if hasattr(self, 'code_converter'):
+                for m in [self.code_converter and self.input_converter]:
+                    for p in m.parameters():
+                        del p.DO_NOT_TRAIN
+                        p.requires_grad = True
 
         self.debug_codes = {}
 
@@ -502,6 +508,61 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
             p.grad *= .2
 
 
+class TransformerDiffusionWithCheaterLatent(nn.Module):
+    def __init__(self, freeze_encoder_until=50000, **kwargs):
+        super().__init__()
+        self.internal_step = 0
+        self.freeze_encoder_until = freeze_encoder_until
+        self.diff = TransformerDiffusion(**kwargs)
+        self.encoder = UpperEncoder(256, 1024, 256)
+        self.encoder = self.encoder.eval()
+
+    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
+        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
+        with torch.set_grad_enabled(encoder_grad_enabled):
+            proj = self.encoder(truth_mel).permute(0,2,1)
+
+        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
+        return diff
+
+    def get_debug_values(self, step, __):
+        self.internal_step = step
+
+    def get_grad_norm_parameter_groups(self):
+        attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))
+        attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers]))
+        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers]))
+        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
+        blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
+        groups = {
+            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
+            'blk1_attention_layers': attn1,
+            'blk2_attention_layers': attn2,
+            'attention_layers': attn1 + attn2,
+            'blk1_ff_layers': ff1,
+            'blk2_ff_layers': ff2,
+            'ff_layers': ff1 + ff2,
+            'block_out_layers': blkout_layers,
+            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
+            'out': list(self.diff.out.parameters()),
+            'x_proj': list(self.diff.inp_block.parameters()),
+            'layers': list(self.diff.layers.parameters()),
+            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
+            'time_embed': list(self.diff.time_embed.parameters()),
+            'encoder': list(self.encoder.parameters()),
+        }
+        return groups
+
+    def before_step(self, step):
+        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
+                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
+        # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
+        # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
+        # directly fiddling with the gradients.
+        for p in scaled_grad_parameters:
+            p.grad *= .2
+
+
 @register_model
 def register_transformer_diffusion12(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])
@@ -524,6 +585,10 @@ def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
 def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
     return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])
 
+@register_model
+def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
+    return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
+
 
 def test_quant_model():
     clip = torch.randn(2, 256, 400)
@@ -646,6 +711,45 @@ def test_ar_model():
     model(clip, ts, cond, conditioning_input=cond)
 
 
+def test_cheater_model():
+    clip = torch.randn(2, 256, 400)
+    ts = torch.LongTensor([600, 600])
+
+    # For music:
+    model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
+                                                  model_channels=1024, contraction_dim=512,
+                                                  prenet_channels=1024, num_heads=8,
+                                                  input_vec_dim=256, num_layers=12, prenet_layers=6,
+                                                  dropout=.1,
+                                                  )
+    diff_weights = torch.load('extracted_diff.pth')
+    model.diff.load_state_dict(diff_weights, strict=False)
+    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth')
+    cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
+                               vqargs={'positional_dims': 1, 'channels': 64,
+                                       'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
+                                       'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
+                                       })
+    cheater_ar.load_state_dict(cheater_ar_weights)
+    model.encoder.load_state_dict(cheater_ar.upper_encoder.state_dict(), strict=True)
+    torch.save(model.state_dict(), 'sample.pth')
+
+    print_network(model)
+    o = model(clip, ts, clip)
+    pg = model.get_grad_norm_parameter_groups()
+
+
+def extract_diff(in_f, out_f, remove_head=False):
+    p = torch.load(in_f)
+    out = {}
+    for k, v in p.items():
+        if k.startswith('diff.'):
+            if remove_head and (k.startswith('diff.input_converter') or k.startswith('diff.code_converter')):
+                continue
+            out[k.replace('diff.', '')] = v
+    torch.save(out, out_f)
+
 
 if __name__ == '__main__':
-    test_vqvae_model()
+    #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
+    test_cheater_model()
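
The central behavioral change in TransformerDiffusionWithCheaterLatent is the freeze-then-unfreeze schedule around its UpperEncoder: the trainer reports its global step through get_debug_values(), and forward() only lets gradients flow into the encoder once that step passes freeze_encoder_until. The stand-alone sketch below is not part of the patch; TinyEncoder, the layer sizes, and the step values are invented for illustration, and only the step-gated torch.set_grad_enabled pattern comes from the diff.

# Minimal sketch (assumptions noted above) of the step-gated encoder freeze used by
# TransformerDiffusionWithCheaterLatent.forward().
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Conv1d(256, 256, 1)

    def forward(self, mel):
        return self.proj(mel)

class FrozenThenTrainable(nn.Module):
    def __init__(self, freeze_encoder_until=50000):
        super().__init__()
        self.internal_step = 0
        self.freeze_encoder_until = freeze_encoder_until
        self.encoder = TinyEncoder().eval()
        self.head = nn.Conv1d(256, 256, 1)

    def get_debug_values(self, step, __):
        # The trainer calls this with its global step; forward() keys off it.
        self.internal_step = step

    def forward(self, truth_mel):
        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
        with torch.set_grad_enabled(encoder_grad_enabled):
            proj = self.encoder(truth_mel)
        return self.head(proj)

m = FrozenThenTrainable(freeze_encoder_until=10)
mel = torch.randn(2, 256, 400)

m.get_debug_values(5, None)            # before the threshold: encoder stays frozen
m(mel).sum().backward()
assert m.encoder.proj.weight.grad is None

m.zero_grad(set_to_none=True)
m.get_debug_values(11, None)           # past the threshold: encoder now receives gradients
m(mel).sum().backward()
assert m.encoder.proj.weight.grad is not None

The patch additionally permutes the encoder output to (batch, time, channels) before passing it to the diffusion model as codes; the sketch skips that detail.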
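
before_step() on the new wrapper mirrors the existing VQVAE wrappers: instead of giving the per-layer out/prenorm parameters their own optimizer group (awkward under ZeroRedundancyOptimizer, per the in-code comment), their gradients are damped in place just before the optimizer consumes them. A rough stand-alone illustration follows; the two-layer toy model and training loop are invented, and only the 0.2 factor and the in-place grad rescale come from the patch.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Stand-in for the prenorm/out parameters whose gradients run ~2 orders of magnitude hotter.
scaled_grad_parameters = list(model[1].parameters())

x = torch.randn(4, 8)
loss = model(x).pow(2).mean()
loss.backward()

# Equivalent of before_step(): rescale the hot gradients right before the optimizer step.
for p in scaled_grad_parameters:
    if p.grad is not None:
        p.grad *= .2

opt.step()
opt.zero_grad()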
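
extract_diff() exists so that a checkpoint saved from one of the wrapper modules (whose keys carry a 'diff.' prefix) can be reloaded into the bare diffusion model, optionally dropping the input_converter/code_converter heads; test_cheater_model() then loads the result into model.diff with strict=False. The usage sketch below is self-contained but assumes the function is importable from the patched transformer_diffusion12.py; the fake state dict and file names are placeholders.

import torch
from models.audio.music.transformer_diffusion12 import extract_diff  # assumed import path

# Fake wrapper-style checkpoint; real ones come from the trainer.
fake_ckpt = {
    'diff.inp_block.weight': torch.zeros(4),
    'diff.code_converter.weight': torch.zeros(4),   # dropped when remove_head=True
    'encoder.proj.weight': torch.zeros(4),          # non-'diff.' keys are discarded
}
torch.save(fake_ckpt, 'wrapper.pth')

extract_diff('wrapper.pth', 'extracted_diff.pth', remove_head=True)
extracted = torch.load('extracted_diff.pth')
assert list(extracted.keys()) == ['inp_block.weight']   # 'diff.' prefix stripped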