This commit is contained in:
James Betker 2022-02-23 18:03:38 -07:00
parent 68726eac74
commit 9a7bbf33df

View File

@ -45,6 +45,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
self.diffusion_fn = self.perform_diffusion_tts self.diffusion_fn = self.perform_diffusion_tts
elif mode == 'vocoder': elif mode == 'vocoder':
self.dvae = load_speech_dvae() self.dvae = load_speech_dvae()
self.dvae.eval()
self.diffusion_fn = self.perform_diffusion_vocoder self.diffusion_fn = self.perform_diffusion_vocoder
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500): def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
@ -115,6 +116,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
projector = self.load_projector().to(self.env['device']) projector = self.load_projector().to(self.env['device'])
projector.eval() projector.eval()
if hasattr(self, 'dvae'):
self.dvae = self.dvae.to(self.env['device'])
# Attempt to fix the random state as much as possible. RNG state will be restored before returning. # Attempt to fix the random state as much as possible. RNG state will be restored before returning.
rng_state = torch.get_rng_state() rng_state = torch.get_rng_state()
@ -143,6 +146,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size() frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
self.model.train() self.model.train()
if hasattr(self, 'dvae'):
self.dvae = self.dvae.to('cpu')
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
return {"frechet_distance": frechet_distance} return {"frechet_distance": frechet_distance}