forked from mrq/DL-Art-School
put frechet_distance on cuda
This commit is contained in:
parent
9a7bbf33df
commit
81017d9696
|
@ -140,7 +140,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
|
||||||
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
|
torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate)
|
||||||
gen_projections = torch.stack(gen_projections, dim=0)
|
gen_projections = torch.stack(gen_projections, dim=0)
|
||||||
real_projections = torch.stack(real_projections, dim=0)
|
real_projections = torch.stack(real_projections, dim=0)
|
||||||
frechet_distance = self.compute_frechet_distance(gen_projections, real_projections)
|
frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
|
||||||
|
|
||||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user