forked from mrq/DL-Art-School
Load diffusion_fid DVAE into the correct cuda device
This commit is contained in:
parent
e1052a5e32
commit
382681a35d
|
@ -48,7 +48,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
if mode == 'tts':
|
if mode == 'tts':
|
||||||
self.diffusion_fn = self.perform_diffusion_tts
|
self.diffusion_fn = self.perform_diffusion_tts
|
||||||
elif mode == 'vocoder':
|
elif mode == 'vocoder':
|
||||||
self.dvae = load_speech_dvae()
|
self.dvae = load_speech_dvae().to(self.env['device'])
|
||||||
self.dvae.eval()
|
self.dvae.eval()
|
||||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user