fix mel outputs

This commit is contained in:
James Betker 2022-07-08 19:51:12 -06:00
parent b99af89c8f
commit 55b9f31825

View File

@ -229,10 +229,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
# 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory.
chunks = torch.split(gen_cheater, 64, dim=-1)
gen_mels = []
gen_wavs = []
for chunk in tqdm(chunks):
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,chunk.shape[-1]*16), progress=True,
model_kwargs={'codes': chunk.permute(0,2,1)})
gen_mels.append(gen_mel)
# 3. And then the MEL back into a spectrogram
output_shape = (1,16,audio.shape[-1]//(16*len(chunks)))
@ -242,6 +244,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
model_kwargs={'codes': gen_mel_denorm})
gen_wav = pixel_shuffle_1d(gen_wav, 16)
gen_wavs.append(gen_wav)
gen_mel = torch.cat(gen_mels, dim=-1)
gen_wav = torch.cat(gen_wavs, dim=-1)
if audio.shape[-1] < 40 * 22050: