diff --git a/codes/scripts/audio/prep_music/test_contrastive_music_pairer.py b/codes/scripts/audio/prep_music/test_contrastive_music_pairer.py
new file mode 100644
index 00000000..e69de29b
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 3c62c0da..0331da94 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -52,7 +52,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                               dropout=0, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, unconditioned_percentage=0)
         self.spec_decoder.load_state_dict(torch.load('../experiments/music_waveform_gen.pth', map_location=torch.device('cpu')))
         self.projector = ContrastiveAudio(model_dim=512, transformer_heads=8, dropout=0, encoder_depth=8, mel_channels=256)
-        #self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))
+        self.projector.load_state_dict(torch.load('../experiments/music_eval_projector.pth', map_location=torch.device('cpu')))
         self.local_modules = {'spec_decoder': self.spec_decoder, 'projector': self.projector}

         if mode == 'spec_decode':
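
Note (not part of the diff): the change above uncomments the checkpoint load for the ContrastiveAudio projector, so '../experiments/music_eval_projector.pth' is now a hard requirement at evaluator construction time. The sketch below illustrates the same load pattern in isolation; the checkpoint path and map_location come from the diff, while the helper name and the module argument are hypothetical and only for illustration.

    import torch

    def load_projector_weights(projector: torch.nn.Module,
                               ckpt_path: str = '../experiments/music_eval_projector.pth') -> torch.nn.Module:
        # Load the serialized state dict onto CPU (matching the diff's map_location),
        # then copy the weights into the already-constructed projector module.
        state = torch.load(ckpt_path, map_location=torch.device('cpu'))
        projector.load_state_dict(state)
        return projector

Because the load is no longer commented out, a missing or stale music_eval_projector.pth will now raise at startup rather than silently running the projector with random weights.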