Distributed FID dataset across processes

This commit is contained in:
James Betker 2021-06-14 09:33:44 -06:00
parent 6b32c87dcb
commit 545f2db170

View File

@ -9,7 +9,7 @@ import trainer.eval.evaluator as evaluator
from pytorch_fid import fid_score
from data import create_dataset
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
from trainer.injectors.gaussian_diffusion_injector import GaussianDiffusionInferenceInjector
from utils.util import opt_get
@ -23,6 +23,10 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator):
self.fid_batch_size = opt_get(opt_eval, ['fid_batch_size'], 64)
assert self.batch_sz is not None
self.dataset = create_dataset(opt_eval['dataset'])
if torch.distributed.is_available() and torch.distributed.is_initialized():
self.sampler = DistributedSampler(self.dataset, shuffle=False, drop_last=True)
else:
self.sampler = SequentialSampler(self.dataset)
self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset.
assert isinstance(self.fid_real_samples, str)
self.gd = GaussianDiffusionInferenceInjector(opt_eval['diffusion_params'], env)
@ -31,7 +35,7 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator):
def perform_eval(self):
# Attempt to make the dataset deterministic.
self.dataset.reset_random()
dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=0)
dataloader = DataLoader(self.dataset, self.batch_sz, sampler=self.sampler, num_workers=0)
fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"]))
os.makedirs(fid_fake_path, exist_ok=True)