Load diffusion_fid DVAE into the correct cuda device

This commit is contained in:
James Betker 2022-03-04 13:42:14 -07:00
parent e1052a5e32
commit 382681a35d

View File

@ -48,7 +48,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
if mode == 'tts':
self.diffusion_fn = self.perform_diffusion_tts
elif mode == 'vocoder':
self.dvae = load_speech_dvae()
self.dvae = load_speech_dvae().to(self.env['device'])
self.dvae.eval()
self.diffusion_fn = self.perform_diffusion_vocoder