forked from mrq/DL-Art-School
fid saving images across all rank fix
This commit is contained in:
parent
6a75bd0777
commit
ae8de0cb9d
|
@ -50,9 +50,10 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator):
|
|||
torch.distributed.all_gather(gather_list, gen)
|
||||
gen = torch.cat(gather_list, dim=0)
|
||||
|
||||
for g in gen:
|
||||
torchvision.utils.save_image(g, osp.join(fid_fake_path, f"{counter}.png"))
|
||||
counter += 1
|
||||
if self.env['rank'] <= 0:
|
||||
for g in gen:
|
||||
torchvision.utils.save_image(g, osp.join(fid_fake_path, f"{counter}.png"))
|
||||
counter += 1
|
||||
|
||||
return {"fid": fid_score.calculate_fid_given_paths([self.fid_real_samples, fid_fake_path], self.fid_batch_size,
|
||||
True, 2048)}
|
Loading…
Reference in New Issue
Block a user