Fix a bug where non-rank-0 is computing FID before all images are saved.

This commit is contained in:
James Betker 2021-06-16 16:27:09 -06:00
parent 68cbbed886
commit 8e3a33e001

View File

@ -55,5 +55,8 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator):
torchvision.utils.save_image(g, osp.join(fid_fake_path, f"{counter}.png"))
counter += 1
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.fid_batch_size,
True, 2048)}
if self.env['rank'] <= 0:
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.fid_batch_size,
True, 2048)}
else:
return {}