diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 62e4ccbc..89c5059c 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -48,7 +48,7 @@ class AudioDiffusionFid(evaluator.Evaluator): if mode == 'tts': self.diffusion_fn = self.perform_diffusion_tts elif mode == 'vocoder': - self.dvae = load_speech_dvae() + self.dvae = load_speech_dvae().to(self.env['device']) self.dvae.eval() self.diffusion_fn = self.perform_diffusion_vocoder