diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index 37b7d905..5f8e892c 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -143,7 +143,8 @@ class AudioDiffusionFid(evaluator.Evaluator): frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device']) if distributed.is_initialized() and distributed.get_world_size() > 1: - frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size() + distributed.all_reduce(frechet_distance) + frechet_distance = frechet_distance / distributed.get_world_size() self.model.train() if hasattr(self, 'dvae'):