From afa41f18043bb6fc9ec5cae34a39aafe83e891a5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 30 Jun 2021 09:44:46 -0600 Subject: [PATCH] Allow hq color jittering and corruptions that are not included in the corruption factor --- codes/data/image_corruptor.py | 16 +++++++++++----- codes/data/image_folder_dataset.py | 10 ++++++++-- codes/models/diffusion/unet_diffusion.py | 7 ++++++- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index 6af7d000..f7000b29 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -25,6 +25,16 @@ if __name__ == '__main__': plt.show() ''' + +def kornia_color_jitter_numpy(img, setting): + if setting * 255 > 1: + # I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format. + img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0) + img = ColorJitter(setting, setting, setting, setting)(img) + img = img.squeeze(0).permute(1,2,0).numpy() + return img + + # Performs image corruption on a list of images from a configurable set of corruption # options. class ImageCorruptor: @@ -107,11 +117,7 @@ class ImageCorruptor: lo_end = 0 hi_end = .2 setting = rand_val * (hi_end - lo_end) + lo_end - if setting * 255 > 1: - # I'm using Kornia's ColorJitter, which requires pytorch arrays in b,c,h,w format. - img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0) - img = ColorJitter(setting, setting, setting, setting)(img) - img = img.squeeze(0).permute(1,2,0).numpy() + img = kornia_color_jitter_numpy(img, setting) elif 'gaussian_blur' in aug: img = cv2.GaussianBlur(img, (0,0), self.blur_scale*rand_val*1.5) elif 'motion_blur' in aug: diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index e263d8f5..63db243e 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -17,7 +17,7 @@ from tqdm import tqdm from data import util # Builds a dataset created from a simple folder containing a list of training/test/validation images. -from data.image_corruptor import ImageCorruptor +from data.image_corruptor import ImageCorruptor, kornia_color_jitter_numpy from data.image_label_parser import VsNetImageLabeler from utils.util import opt_get @@ -50,6 +50,7 @@ class ImageFolderDataset: self.rgb_n1_to_1 = opt_get(opt, ['rgb_n1_to_1'], False) self.force_square = opt_get(opt, ['force_square'], True) self.fixed_parameters = {k: torch.tensor(v) for k, v in opt_get(opt, ['fixed_parameters'], {}).items()} + self.all_image_color_jitter = opt_get(opt, ['all_image_color_jitter'], 0) if 'normalize' in opt.keys(): if opt['normalize'] == 'stylegan2_norm': self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) @@ -155,6 +156,10 @@ class ImageFolderDataset: dim = min(h, w) hq = hq[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :] + # Perform color jittering on the HQ image if specified. The given value should be between [0,1]. + if self.all_image_color_jitter > 0: + hq = kornia_color_jitter_numpy(hq, self.all_image_color_jitter) + if self.labeler: assert hq.shape[0] == hq.shape[1] # This just has not been accomodated yet. dim = hq.shape[0] @@ -273,7 +278,8 @@ if __name__ == '__main__': 'disable_flip': True, 'fixed_corruptions': ['lq_resampling', 'jpeg-medium', 'gaussian_blur', 'noise', 'color_jitter'], 'num_corrupts_per_image': 0, - 'corruption_blur_scale': 1 + 'corruption_blur_scale': 1, + 'all_image_color_jitter': .1, } ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0, batch_size=64) diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index a9520bfe..cce553e9 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -661,13 +661,18 @@ class SuperResModel(UNetModel): """ def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs): - self.num_corruptions = 0 + self.num_corruptions = num_corruptions super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs) def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs): b, _, new_height, new_width = x.shape upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") if corruption_factor is not None: + if corruption_factor.shape[1] != self.num_corruptions: + if not hasattr(self, '_corruption_factor_warning'): + print(f"Warning! Dataloader gave us {corruption_factor.shape[1]} dim but we are only processing {self.num_corruptions}. The last n corruptions will be truncated.") + self._corruption_factor_warning = True + corruption_factor = corruption_factor[:, :self.num_corruptions] corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width) else: corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device)