From 6b32c87dcb2343a67ea7edcab074986eff19a1fa Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 14 Jun 2021 09:27:43 -0600 Subject: [PATCH] Try to make diffusion fid more deterministic --- codes/data/image_corruptor.py | 24 ++++++++++++++++-------- codes/data/image_folder_dataset.py | 3 +++ codes/trainer/eval/sr_diffusion_fid.py | 8 ++++++-- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index a02d63b1..58073161 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -8,12 +8,6 @@ from PIL import Image from io import BytesIO -# Feeds a random uniform through a cosine distribution to slightly bias corruptions towards "uncorrupted". -# Return is on [0,1] with a bias towards 0. -def get_rand(): - r = random.random() - return 1 - cos(r * pi / 2) - # Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data) ''' if __name__ == '__main__': @@ -28,12 +22,26 @@ if __name__ == '__main__': # options. class ImageCorruptor: def __init__(self, opt): + self.opt = opt self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1 self.fixed_corruptions = opt['fixed_corruptions'] if 'fixed_corruptions' in opt.keys() else [] self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 0 if self.num_corrupts == 0: return self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [] + self.reset_random() + + def reset_random(self): + if 'random_seed' in self.opt.keys(): + self.rand = random.Random(self.opt['random_seed']) + else: + self.rand = random.Random() + + # Feeds a random uniform through a cosine distribution to slightly bias corruptions towards "uncorrupted". + # Return is on [0,1] with a bias towards 0. + def get_rand(self): + r = self.rand.random() + return 1 - cos(r * pi / 2) def corrupt_images(self, imgs, return_entropy=False): if self.num_corrupts == 0 and not self.fixed_corruptions: @@ -53,10 +61,10 @@ class ImageCorruptor: applied_augs = augmentations + self.fixed_corruptions for img in imgs: for aug in augmentations: - r = get_rand() + r = self.get_rand() img = self.apply_corruption(img, aug, r, applied_augs) for aug in self.fixed_corruptions: - r = get_rand() + r = self.get_rand() img = self.apply_corruption(img, aug, r, applied_augs) entropy.append(r) corrupted_imgs.append(img) diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index 8b7415c9..db51c383 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -124,6 +124,9 @@ class ImageFolderDataset: ls, ent = self.corruptor.corrupt_images(ls, return_entropy=True) return ls, ent + def reset_random(self): + self.corruptor.reset_random() + def __len__(self): return self.len diff --git a/codes/trainer/eval/sr_diffusion_fid.py b/codes/trainer/eval/sr_diffusion_fid.py index 13e50ea3..172d2561 100644 --- a/codes/trainer/eval/sr_diffusion_fid.py +++ b/codes/trainer/eval/sr_diffusion_fid.py @@ -15,6 +15,7 @@ from trainer.injectors.gaussian_diffusion_injector import GaussianDiffusionInfer from utils.util import opt_get +# Performs a FID evaluation on a diffusion network class SrDiffusionFidEvaluator(evaluator.Evaluator): def __init__(self, model, opt_eval, env): super().__init__(model, opt_eval, env) @@ -24,15 +25,18 @@ class SrDiffusionFidEvaluator(evaluator.Evaluator): self.dataset = create_dataset(opt_eval['dataset']) self.fid_real_samples = opt_eval['dataset']['paths'] # This is assumed to exist for the given dataset. assert isinstance(self.fid_real_samples, str) - self.dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=1) self.gd = GaussianDiffusionInferenceInjector(opt_eval['diffusion_params'], env) self.out_key = opt_eval['diffusion_params']['out'] def perform_eval(self): + # Attempt to make the dataset deterministic. + self.dataset.reset_random() + dataloader = DataLoader(self.dataset, self.batch_sz, shuffle=False, num_workers=0) + fid_fake_path = osp.join(self.env['base_path'], "..", "fid", str(self.env["step"])) os.makedirs(fid_fake_path, exist_ok=True) counter = 0 - for batch in tqdm(self.dataloader): + for batch in tqdm(dataloader): batch = {k: v.to(self.env['device']) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} gen = self.gd(batch)[self.out_key]