diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py
index e1c66852..469c15cb 100644
--- a/codes/models/audio/music/transformer_diffusion7.py
+++ b/codes/models/audio/music/transformer_diffusion7.py
@@ -199,9 +199,11 @@ class TransformerDiffusion(nn.Module):
 
 
 class TransformerDiffusionWithQuantizer(nn.Module):
-    def __init__(self, **kwargs):
+    def __init__(self, freeze_quantizer_until=20000, **kwargs):
         super().__init__()
+        self.internal_step = 0
+        self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
         from models.audio.mel2vec import ContrastiveTrainingWrapper
         self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
@@ -210,13 +212,24 @@ class TransformerDiffusionWithQuantizer(nn.Module):
 
     def update_for_step(self, step, *args):
         self.internal_step = step
+        qstep = max(0, self.internal_step - self.freeze_quantizer_until)
         self.m2v.quantizer.temperature = max(
-            self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**step,
+            self.m2v.max_gumbel_temperature * self.m2v.gumbel_temperature_decay**qstep,
             self.m2v.min_gumbel_temperature,
         )
 
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
-        proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+        quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
+        with torch.set_grad_enabled(quant_grad_enabled):
+            proj = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+
+        # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
+        if not quant_grad_enabled:
+            unused = 0
+            for p in self.m2v.parameters():
+                unused = unused + p.mean() * 0
+            proj = proj + unused
+
         return self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
 
@@ -256,12 +269,12 @@ if __name__ == '__main__':
     ts = torch.LongTensor([600, 600])
     model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=2048, num_layers=16)
 
-    quant_weights = torch.load('X:\\dlas\\experiments\\train_music_quant\\models\\1000_generator.pth')
+    #quant_weights = torch.load('X:\\dlas\\experiments\\train_music_quant\\models\\1000_generator.pth')
     #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth')
-    model.m2v.load_state_dict(quant_weights, strict=False)
+    #model.m2v.load_state_dict(quant_weights, strict=False)
     #model.diff.load_state_dict(diff_weights)
 
-    torch.save(model.state_dict(), 'sample.pth')
+    #torch.save(model.state_dict(), 'sample.pth')
     print_network(model)
     o = model(clip, ts, clip, cond)
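
# --- Not part of the diff above: a minimal, self-contained sketch of the pattern it introduces. ---
# The forward() change freezes the quantizer for the first `freeze_quantizer_until` steps by running
# it under torch.set_grad_enabled(False), then keeps DistributedDataParallel from complaining about
# parameters that never receive gradients by touching the frozen parameters with a zero-valued term,
# so they still appear in the autograd graph. All names below (FreezeUntilWrapper, Quantizer, head)
# are hypothetical stand-ins for illustration only, not code from the repository.
import torch
import torch.nn as nn


class Quantizer(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class FreezeUntilWrapper(nn.Module):
    def __init__(self, freeze_until=100, dim=16):
        super().__init__()
        self.internal_step = 0
        self.freeze_until = freeze_until
        self.quantizer = Quantizer(dim)
        self.head = nn.Linear(dim, 1)

    def update_for_step(self, step):
        self.internal_step = step

    def forward(self, x):
        grad_enabled = self.internal_step > self.freeze_until
        # While frozen, run the quantizer without building a graph for its weights.
        with torch.set_grad_enabled(grad_enabled):
            codes = self.quantizer(x)

        if not grad_enabled:
            # DDP expects every parameter to show up in the graph each step; a zero-valued
            # term keeps the frozen parameters "used" without changing the output.
            unused = sum(p.mean() * 0 for p in self.quantizer.parameters())
            codes = codes + unused

        return self.head(codes).mean()


if __name__ == '__main__':
    model = FreezeUntilWrapper(freeze_until=5)
    x = torch.randn(4, 16)
    for step in range(10):
        model.update_for_step(step)
        loss = model(x)
        loss.backward()
        # Through step 5 the quantizer's gradients are all zero (frozen); afterwards they flow.
        print(step, model.quantizer.proj.weight.grad.abs().sum().item())
        model.zero_grad()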