diff --git a/codes/trainer/eval/sr_diffusion_fid.py b/codes/trainer/eval/sr_diffusion_fid.py index 7730f8cb..eba49cd7 100644 --- a/codes/trainer/eval/sr_diffusion_fid.py +++ b/codes/trainer/eval/sr_diffusion_fid.py @@ -50,8 +50,8 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator): torch.distributed.all_gather(gather_list, gen) gen = torch.cat(gather_list, dim=0) - for b in range(self.batch_sz): - torchvision.utils.save_image(gen[b], osp.join(fid_fake_path, "%i_.png" % (counter))) + for g in gen: + torchvision.utils.save_image(g, osp.join(fid_fake_path, f"{counter}.png")) counter += 1 return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.fid_batch_size,