fid saving images across all rank fix

This commit is contained in:
James Betker 2021-06-15 10:31:07 -06:00
parent 6a75bd0777
commit ae8de0cb9d

View File

@ -50,6 +50,7 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator):
torch.distributed.all_gather(gather_list, gen)
gen = torch.cat(gather_list, dim=0)
if self.env['rank'] <= 0:
for g in gen:
torchvision.utils.save_image(g, osp.join(fid_fake_path, f"{counter}.png"))
counter += 1