diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index 515093b0..9ed43e8f 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -58,9 +58,11 @@ class UpperConditioningEncoder(nn.Module): class GptMusicLower(nn.Module): - def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True): + def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, + num_upper_groups=4, fp16=True, freeze_upper_until=0): super().__init__() self.internal_step = 0 + self.freeze_upper_until = freeze_upper_until self.num_groups = num_target_groups self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False) @@ -103,10 +105,16 @@ class GptMusicLower(nn.Module): def forward(self, mel, conditioning, return_latent=False): + unused_params = [] with torch.no_grad(): self.target_quantizer.eval() codes = self.target_quantizer.get_codes(mel) - upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) + if self.freeze_upper_until > self.internal_step: + with torch.no_grad(): + upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) + unused_params.extend(list(self.upper_quantizer.parameters())) + else: + upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1) # Allow the upper vector to fully attend to itself (the whole thing is a prior.) upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear') upper_vector = upper_vector.permute(0,2,1) @@ -135,6 +143,11 @@ class GptMusicLower(nn.Module): loss = F.cross_entropy(logits, targets[:,:,i]) losses = losses + loss + unused_adder = 0 + for p in unused_params: + unused_adder = unused_adder + p.mean() * 0 + losses = losses + unused_adder + return losses / self.num_groups, upper_diversity def get_grad_norm_parameter_groups(self): @@ -256,7 +269,7 @@ def test_lower(): dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000) base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu'))) - model = GptMusicLower(512, 8, fp16=False) + model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100) model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False) torch.save(model.state_dict(), "sample.pth") mel = torch.randn(2,256,400)