This commit is contained in:
James Betker 2022-02-23 21:28:24 -07:00
parent e6824e398f
commit 7c17c8e674

View File

@ -143,7 +143,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device']) frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
if distributed.is_initialized() and distributed.get_world_size() > 1: if distributed.is_initialized() and distributed.get_world_size() > 1:
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size() distributed.all_reduce(frechet_distance)
frechet_distance = frechet_distance / distributed.get_world_size()
self.model.train() self.model.train()
if hasattr(self, 'dvae'): if hasattr(self, 'dvae'):