This commit is contained in:
James Betker 2022-03-15 23:52:48 -06:00
parent 3f244f6a68
commit 8437bb0c53
2 changed files with 5 additions and 2 deletions

View File

@@ -53,6 +53,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
self.local_modules['dvae'] = load_speech_dvae().cpu()
self.diffusion_fn = self.perform_diffusion_vocoder
elif mode == 'tts9_mel':
mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
self.local_modules['dvae'] = load_speech_dvae().cpu()
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
@@ -136,6 +137,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
model_kwargs={'aligned_conditioning': mel_codes,
'conditioning_input': univnet_mel})
# denormalize mel
gen_mel = ((gen_mel+1)/2)*(self.mel_max-self.mel_min)+self.mel_min
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
real_dec = self.local_modules['vocoder'].inference(univnet_mel)

View File

@@ -28,7 +28,7 @@ class MelSpectrogramInjector(Injector):
self.mel_max = mel_max
self.mel_min = mel_min
else:
self.mel_norms = None
self.mel_max = None
def forward(self, state):
inp = state[self.input]
@@ -37,7 +37,7 @@ class MelSpectrogramInjector(Injector):
assert len(inp.shape) == 2
self.stft = self.stft.to(inp.device)
mel = self.stft.mel_spectrogram(inp)
if self.mel_norms is not None:
if self.mel_max is not None:
mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
return {self.output: mel}