forked from mrq/DL-Art-School
f
This commit is contained in:
parent
68726eac74
commit
9a7bbf33df
|
@ -45,6 +45,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
self.diffusion_fn = self.perform_diffusion_tts
|
||||
elif mode == 'vocoder':
|
||||
self.dvae = load_speech_dvae()
|
||||
self.dvae.eval()
|
||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||
|
||||
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
|
||||
|
@ -115,6 +116,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
|
||||
projector = self.load_projector().to(self.env['device'])
|
||||
projector.eval()
|
||||
if hasattr(self, 'dvae'):
|
||||
self.dvae = self.dvae.to(self.env['device'])
|
||||
|
||||
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
||||
rng_state = torch.get_rng_state()
|
||||
|
@ -143,6 +146,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
|||
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
||||
|
||||
self.model.train()
|
||||
if hasattr(self, 'dvae'):
|
||||
self.dvae = self.dvae.to('cpu')
|
||||
torch.set_rng_state(rng_state)
|
||||
|
||||
return {"frechet_distance": frechet_distance}
|
||||
|
|
Loading…
Reference in New Issue
Block a user