forked from mrq/DL-Art-School
fixes
This commit is contained in:
parent
3f244f6a68
commit
8437bb0c53
|
@ -53,6 +53,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||
elif mode == 'tts9_mel':
|
||||
mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
|
||||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
|
||||
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
|
||||
|
@ -136,6 +137,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
|
||||
model_kwargs={'aligned_conditioning': mel_codes,
|
||||
'conditioning_input': univnet_mel})
|
||||
# denormalize mel
|
||||
gen_mel = ((gen_mel+1)/2)*(self.mel_max-self.mel_min)+self.mel_min
|
||||
|
||||
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
|
||||
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
||||
|
|
|
@ -28,7 +28,7 @@ class MelSpectrogramInjector(Injector):
|
|||
self.mel_max = mel_max
|
||||
self.mel_min = mel_min
|
||||
else:
|
||||
self.mel_norms = None
|
||||
self.mel_max = None
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
|
@ -37,7 +37,7 @@ class MelSpectrogramInjector(Injector):
|
|||
assert len(inp.shape) == 2
|
||||
self.stft = self.stft.to(inp.device)
|
||||
mel = self.stft.mel_spectrogram(inp)
|
||||
if self.mel_norms is not None:
|
||||
if self.mel_max is not None:
|
||||
mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
|
||||
return {self.output: mel}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user