forked from mrq/DL-Art-School
Load diffusion_fid DVAE into the correct cuda device
This commit is contained in:
parent
e1052a5e32
commit
382681a35d
|
@ -48,7 +48,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
if mode == 'tts':
|
||||
self.diffusion_fn = self.perform_diffusion_tts
|
||||
elif mode == 'vocoder':
|
||||
self.dvae = load_speech_dvae()
|
||||
self.dvae = load_speech_dvae().to(self.env['device'])
|
||||
self.dvae.eval()
|
||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user