diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py
index 1ac133a2..cbf5a262 100644
--- a/codes/trainer/eval/audio_diffusion_fid.py
+++ b/codes/trainer/eval/audio_diffusion_fid.py
@@ -45,6 +45,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
             self.diffusion_fn = self.perform_diffusion_tts
         elif mode == 'vocoder':
             self.dvae = load_speech_dvae()
+            self.dvae.eval()
             self.diffusion_fn = self.perform_diffusion_vocoder
 
     def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
@@ -115,6 +116,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
 
         projector = self.load_projector().to(self.env['device'])
         projector.eval()
+        if hasattr(self, 'dvae'):
+            self.dvae = self.dvae.to(self.env['device'])
 
         # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
         rng_state = torch.get_rng_state()
@@ -143,6 +146,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
             frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
 
         self.model.train()
+        if hasattr(self, 'dvae'):
+            self.dvae = self.dvae.to('cpu')
         torch.set_rng_state(rng_state)
 
         return {"frechet_distance": frechet_distance}