Load diffusion_fid DVAE into the correct cuda device

This commit is contained in:
James Betker 2022-03-04 13:42:14 -07:00
parent e1052a5e32
commit 382681a35d

View File

@ -48,7 +48,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
if mode == 'tts': if mode == 'tts':
self.diffusion_fn = self.perform_diffusion_tts self.diffusion_fn = self.perform_diffusion_tts
elif mode == 'vocoder': elif mode == 'vocoder':
self.dvae = load_speech_dvae() self.dvae = load_speech_dvae().to(self.env['device'])
self.dvae.eval() self.dvae.eval()
self.diffusion_fn = self.perform_diffusion_vocoder self.diffusion_fn = self.perform_diffusion_vocoder