diff --git a/codes/trainer/eval/sr_diffusion_fid.py b/codes/trainer/eval/sr_diffusion_fid.py index 172d2561..7730f8cb 100644 --- a/codes/trainer/eval/sr_diffusion_fid.py +++ b/codes/trainer/eval/sr_diffusion_fid.py @@ -9,7 +9,7 @@ import trainer.eval.evaluator as evaluator from pytorch_fid import fid_score from data import create_dataset -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler from trainer.injectors.gaussian_diffusion_injector import GaussianDiffusionInferenceInjector from utils.util import opt_get @@ -23,6 +23,10 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator): self.fid_batch_size = opt_get(opt_eval, ['fid_batch_size'], 64) assert self.batch_sz is not None self.dataset = create_dataset(opt_eval['dataset']) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + self.sampler = DistributedSampler(self.dataset, shuffle=False, drop_last=True) + else: + self.sampler = SequentialSampler(self.dataset) self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset. assert isinstance(self.fid_real_samples, str) self.gd = GaussianDiffusionInferenceInjector(opt_eval['diffusion_params'], env) @@ -31,7 +35,7 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator): def perform_eval(self): # Attempt to make the dataset deterministic. self.dataset.reset_random() - dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=0) + dataloader = DataLoader(self.dataset, self.batch_sz, sampler=self.sampler, num_workers=0) fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"])) os.makedirs(fid_fake_path, exist_ok=True)