diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py
index 60385d0f..95927eec 100644
--- a/codes/models/audio/music/unet_diffusion_music_codes.py
+++ b/codes/models/audio/music/unet_diffusion_music_codes.py
@@ -714,7 +714,7 @@ class UNetMusicModelWithQuantizer(nn.Module):
             self.m2v.min_gumbel_temperature,
         )
 
-    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, disable_diversity=False, conditioning_input=None, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
             proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
@@ -749,9 +749,14 @@ if __name__ == '__main__':
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=512, num_res_blocks=3, input_vec_dim=1024,
+    model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024,
                                         attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1,
-                                        use_scale_shift_norm=True, dropout=.1, num_heads=8)
+                                        use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4)
     print_network(model)
+
+    quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
+    model.m2v.load_state_dict(quant_weights, strict=False)
+    torch.save(model.state_dict(), 'sample.pth')
+    model(clip, ts, cond)
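
Not part of the patch: a minimal usage sketch of the revised forward interface, assuming the import path implied by the file location and reusing the tensor shapes and constructor arguments from the __main__ block above; the conditioning_free=True call is an assumption about how the flag is exercised for an unconditioned pass.

    # Illustrative sketch only; import path inferred from
    # codes/models/audio/music/unet_diffusion_music_codes.py.
    import torch
    from models.audio.music.unet_diffusion_music_codes import UNetMusicModelWithQuantizer

    model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640,
                                        num_res_blocks=3, input_vec_dim=1024,
                                        attention_resolutions=(2, 4), channel_mult=(1, 2, 3), dims=1,
                                        use_scale_shift_norm=True, dropout=.1, num_heads=8,
                                        unconditioned_percentage=.4)

    clip = torch.randn(2, 256, 400)       # noised input at the given diffusion timesteps
    truth_mel = torch.randn(2, 256, 400)  # mel input routed through the m2v quantizer
    ts = torch.LongTensor([600, 600])     # diffusion timesteps, one per batch element

    # Conditioned pass; conditioning_input defaults to None, so existing call sites keep working.
    out = model(clip, ts, truth_mel)

    # Unconditioned pass (assumption: conditioning_free=True drops the conditioning signal).
    out_free = model(clip, ts, truth_mel, conditioning_free=True)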