forked from mrq/DL-Art-School
fix mel outputs
This commit is contained in:
parent
b99af89c8f
commit
55b9f31825
|
@ -229,10 +229,12 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
# 2. Decode the cheater into a MEL. This operation and the next need to be chunked to make them feasible to perform within GPU memory.
|
||||
chunks = torch.split(gen_cheater, 64, dim=-1)
|
||||
gen_mels = []
|
||||
gen_wavs = []
|
||||
for chunk in tqdm(chunks):
|
||||
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,chunk.shape[-1]*16), progress=True,
|
||||
model_kwargs={'codes': chunk.permute(0,2,1)})
|
||||
gen_mels.append(gen_mel)
|
||||
|
||||
# 3. And then the MEL back into a spectrogram
|
||||
output_shape = (1,16,audio.shape[-1]//(16*len(chunks)))
|
||||
|
@ -242,6 +244,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
|||
model_kwargs={'codes': gen_mel_denorm})
|
||||
gen_wav = pixel_shuffle_1d(gen_wav, 16)
|
||||
gen_wavs.append(gen_wav)
|
||||
gen_mel = torch.cat(gen_mels, dim=-1)
|
||||
gen_wav = torch.cat(gen_wavs, dim=-1)
|
||||
|
||||
if audio.shape[-1] < 40 * 22050:
|
||||
|
|
Loading…
Reference in New Issue
Block a user