From 81017d96963489748706e659a9841b6058a7b3a9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 23 Feb 2022 21:21:13 -0700 Subject: [PATCH] put frechet_distance on cuda --- codes/trainer/eval/audio_diffusion_fid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index cbf5a262..37b7d905 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -140,7 +140,7 @@ class AudioDiffusionFid(evaluator.Evaluator): torchaudio.save(os.path.join(save_path, f"{self.env['rank']}_{i}_real.wav"), ref.squeeze(0).cpu(), sample_rate) gen_projections = torch.stack(gen_projections, dim=0) real_projections = torch.stack(real_projections, dim=0) - frechet_distance = self.compute_frechet_distance(gen_projections, real_projections) + frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device']) if distributed.is_initialized() and distributed.get_world_size() > 1: frechet_distance = distributed.all_reduce(frechet_distance) / distributed.get_world_size()