From 55b9f3182595b5013fd0565ad2d1cdb29a6771ce Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 8 Jul 2022 19:51:12 -0600 Subject: [PATCH] fix mel outputs --- codes/trainer/eval/music_diffusion_fid.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index 4a886983..c6f8d666 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -229,10 +229,12 @@ class MusicDiffusionFid(evaluator.Evaluator): # 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory. chunks = torch.split(gen_cheater, 64, dim=-1) + gen_mels = [] gen_wavs = [] for chunk in tqdm(chunks): gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,chunk.shape[-1]*16), progress=True, model_kwargs={'codes': chunk.permute(0,2,1)}) + gen_mels.append(gen_mel) # 3. And then the MEL back into a spectrogram output_shape = (1,16,audio.shape[-1]//(16*len(chunks))) @@ -242,6 +244,7 @@ class MusicDiffusionFid(evaluator.Evaluator): model_kwargs={'codes': gen_mel_denorm}) gen_wav = pixel_shuffle_1d(gen_wav, 16) gen_wavs.append(gen_wav) + gen_mel = torch.cat(gen_mels, dim=-1) gen_wav = torch.cat(gen_wavs, dim=-1) if audio.shape[-1] < 40 * 22050: