forked from mrq/DL-Art-School
fixes
This commit is contained in:
parent
3f244f6a68
commit
8437bb0c53
|
@ -53,6 +53,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||||
elif mode == 'tts9_mel':
|
elif mode == 'tts9_mel':
|
||||||
|
mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
|
||||||
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
self.local_modules['dvae'] = load_speech_dvae().cpu()
|
||||||
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
|
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
|
||||||
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
|
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
|
||||||
|
@ -136,6 +137,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
|
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape,
|
||||||
model_kwargs={'aligned_conditioning': mel_codes,
|
model_kwargs={'aligned_conditioning': mel_codes,
|
||||||
'conditioning_input': univnet_mel})
|
'conditioning_input': univnet_mel})
|
||||||
|
# denormalize mel
|
||||||
|
gen_mel = ((gen_mel+1)/2)*(self.mel_max-self.mel_min)+self.mel_min
|
||||||
|
|
||||||
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
|
gen_wav = self.local_modules['vocoder'].inference(gen_mel)
|
||||||
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
real_dec = self.local_modules['vocoder'].inference(univnet_mel)
|
||||||
|
|
|
@ -28,7 +28,7 @@ class MelSpectrogramInjector(Injector):
|
||||||
self.mel_max = mel_max
|
self.mel_max = mel_max
|
||||||
self.mel_min = mel_min
|
self.mel_min = mel_min
|
||||||
else:
|
else:
|
||||||
self.mel_norms = None
|
self.mel_max = None
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
inp = state[self.input]
|
inp = state[self.input]
|
||||||
|
@ -37,7 +37,7 @@ class MelSpectrogramInjector(Injector):
|
||||||
assert len(inp.shape) == 2
|
assert len(inp.shape) == 2
|
||||||
self.stft = self.stft.to(inp.device)
|
self.stft = self.stft.to(inp.device)
|
||||||
mel = self.stft.mel_spectrogram(inp)
|
mel = self.stft.mel_spectrogram(inp)
|
||||||
if self.mel_norms is not None:
|
if self.mel_max is not None:
|
||||||
mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
|
mel = 2 * ((mel - self.mel_min) / (self.mel_max - self.mel_min)) - 1
|
||||||
return {self.output: mel}
|
return {self.output: mel}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user