gurgl
This commit is contained in:
parent
e6824e398f
commit
7c17c8e674
|
@ -143,7 +143,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
||||||
|
|
||||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
distributed.all_reduce(frechet_distance)
|
||||||
|
frechet_distance = frechet_distance / distributed.get_world_size()
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
if hasattr(self, 'dvae'):
|
if hasattr(self, 'dvae'):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user