forked from mrq/DL-Art-School
f
This commit is contained in:
parent
68726eac74
commit
9a7bbf33df
|
@ -45,6 +45,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
self.diffusion_fn = self.perform_diffusion_tts
|
self.diffusion_fn = self.perform_diffusion_tts
|
||||||
elif mode == 'vocoder':
|
elif mode == 'vocoder':
|
||||||
self.dvae = load_speech_dvae()
|
self.dvae = load_speech_dvae()
|
||||||
|
self.dvae.eval()
|
||||||
self.diffusion_fn = self.perform_diffusion_vocoder
|
self.diffusion_fn = self.perform_diffusion_vocoder
|
||||||
|
|
||||||
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
|
def perform_diffusion_tts(self, audio, codes, text, sample_rate=5500):
|
||||||
|
@ -115,6 +116,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
projector = self.load_projector().to(self.env['device'])
|
projector = self.load_projector().to(self.env['device'])
|
||||||
projector.eval()
|
projector.eval()
|
||||||
|
if hasattr(self, 'dvae'):
|
||||||
|
self.dvae = self.dvae.to(self.env['device'])
|
||||||
|
|
||||||
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
# Attempt to fix the random state as much as possible. RNG state will be restored before returning.
|
||||||
rng_state = torch.get_rng_state()
|
rng_state = torch.get_rng_state()
|
||||||
|
@ -143,6 +146,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
if hasattr(self, 'dvae'):
|
||||||
|
self.dvae = self.dvae.to('cpu')
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
|
|
||||||
return {"frechet_distance": frechet_distance}
|
return {"frechet_distance": frechet_distance}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user