From 8c8efbe1319c8f438300bffea9c9a5a00518279a Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 19 Jun 2022 17:54:08 -0600 Subject: [PATCH] fix code_emb --- .../audio/music/transformer_diffusion12.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 58a11870..e1b748ca 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -210,8 +210,8 @@ class TransformerDiffusion(nn.Module): def timestep_independent(self, prior, expected_seq_len): if self.new_code_expansion: - code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1) - code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb) + prior = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1) + code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior) code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. @@ -732,14 +732,14 @@ def test_cheater_model(): # For music: model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512, - model_channels=1024, contraction_dim=512, - prenet_channels=1024, num_heads=8, - input_vec_dim=256, num_layers=12, prenet_layers=6, + model_channels=1536, contraction_dim=768, + prenet_channels=1024, num_heads=12, + input_vec_dim=256, num_layers=20, prenet_layers=6, dropout=.1, new_code_expansion=True, ) - diff_weights = torch.load('extracted_diff.pth') - model.diff.load_state_dict(diff_weights, strict=False) - cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth') + #diff_weights = torch.load('extracted_diff.pth') + #model.diff.load_state_dict(diff_weights, strict=False) + cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\60000_generator_ema.pth') cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4, vqargs= {'positional_dims': 1, 'channels': 64, 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,