diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 8fb8b2c0..b5e010c1 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -230,9 +230,10 @@ class MusicDiffusionFid(evaluator.Evaluator): # 1. Generate the cheater latent using the input as a reference. sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop - gen_cheater = sampler(self.model, cheater.shape, progress=True, + output_shape = (1, 256, cheater.shape[-1]-80) + gen_cheater = sampler(self.model, output_shape, progress=True, causal=self.causal, causal_slope=self.causal_slope, - model_kwargs={'conditioning_input': cheater}) + model_kwargs={'conditioning_input': cheater, 'cond_start': 40}) # 2. Decode the cheater into a MEL gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,