From 8437bb0c53a3e8e575d49a0ef5c35d145bbbb594 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 15 Mar 2022 23:52:48 -0600 Subject: [PATCH] fixes --- codes/trainer/eval/audio_diffusion_fid.py | 3 +++ codes/trainer/injectors/audio_injectors.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 196a642c..83a96e37 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -53,6 +53,7 @@ class AudioDiffusionFid(evaluator.Evaluator): self.local_modules['dvae'] = load_speech_dvae().cpu() self.diffusion_fn = self.perform_diffusion_vocoder elif mode == 'tts9_mel': + mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth') self.local_modules['dvae'] = load_speech_dvae().cpu() self.local_modules['vocoder'] = load_univnet_vocoder().cpu() self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes @@ -136,6 +137,8 @@ class AudioDiffusionFid(evaluator.Evaluator): gen_mel = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': univnet_mel}) + # denormalize mel + gen_mel = ((gen_mel+1)/2)*(self.mel_max-self.mel_min)+self.mel_min gen_wav = self.local_modules['vocoder'].inference(gen_mel) real_dec = self.local_modules['vocoder'].inference(univnet_mel) diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index e255cd3e..dbf858de 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -28,7 +28,7 @@ class MelSpectrogramInjector(Injector): self.mel_max = mel_max self.mel_min = mel_min else: - self.mel_norms = None + self.mel_max = None def forward(self, state): inp = state[self.input] @@ -37,7 +37,7 @@ class MelSpectrogramInjector(Injector): assert len(inp.shape) == 2 self.stft = self.stft.to(inp.device) mel = self.stft.mel_spectrogram(inp) - if self.mel_norms is not None: + if self.mel_max is not None: mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1 return {self.output: mel}