From 6c6e82406eccbb5b96aae358a41dda455f74f111 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 7 Jun 2021 09:13:54 -0600 Subject: [PATCH] Pass a corruption factor through the dataset into the upsampling network The intuition is this will help guide the network to make better informed decisions about how it performs upsampling based on how it perceives the underlying content. (I'm giving up on letting networks detect their own quality - I'm not convinced it is actually feasible) --- codes/data/image_corruptor.py | 67 ++++++++++++++---------- codes/data/image_folder_dataset.py | 19 +++---- codes/models/diffusion/unet_diffusion.py | 14 +++-- 3 files changed, 60 insertions(+), 40 deletions(-) diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py index 16f885d2..4c077eed 100644 --- a/codes/data/image_corruptor.py +++ b/codes/data/image_corruptor.py @@ -1,10 +1,29 @@ import random +from math import cos, pi + import cv2 import numpy as np from data.util import read_img from PIL import Image from io import BytesIO + +# Feeds a random uniform through a cosine distribution to slightly bias corruptions towards "uncorrupted". +# Return is on [0,1] with a bias towards 0. +def get_rand(): + r = random.random() + return 1 - cos(r * pi / 2) + +# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data) +''' +if __name__ == '__main__': + import numpy as np + import matplotlib.pyplot as plt + data = np.asarray([get_rand() for _ in range(5000)]) + plt.plot(data, np.random.uniform(size=(5000,)), 'x') + plt.show() +''' + # Performs image corruption on a list of images from a configurable set of corruption # options. class ImageCorruptor: @@ -16,7 +35,7 @@ class ImageCorruptor: return self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else [] - def corrupt_images(self, imgs): + def corrupt_images(self, imgs, return_entropy=False): if self.num_corrupts == 0 and not self.fixed_corruptions: return imgs @@ -24,51 +43,45 @@ class ImageCorruptor: augmentations = [] else: augmentations = random.choices(self.random_corruptions, k=self.num_corrupts) - # Source of entropy, which should be used across all images. - rand_int_f = random.randint(1, 999999) - rand_int_a = random.randint(1, 999999) + # Sources of entropy corrupted_imgs = [] + entropy = [] applied_augs = augmentations + self.fixed_corruptions for img in imgs: for aug in augmentations: - img = self.apply_corruption(img, aug, rand_int_a, applied_augs) + r = get_rand() + img = self.apply_corruption(img, aug, r, applied_augs) for aug in self.fixed_corruptions: - img = self.apply_corruption(img, aug, rand_int_f, applied_augs) + r = get_rand() + img = self.apply_corruption(img, aug, r, applied_augs) + entropy.append(r) corrupted_imgs.append(img) - return corrupted_imgs + if return_entropy: + return corrupted_imgs, entropy + else: + return corrupted_imgs - def apply_corruption(self, img, aug, rand_int, applied_augmentations): + def apply_corruption(self, img, aug, rand_val, applied_augmentations): if 'color_quantization' in aug: # Color quantization - quant_div = 2 ** ((rand_int % 3) + 2) + quant_div = 2 ** (int(rand_val * 10 / 3) + 2) img = img * 255 img = (img // quant_div) * quant_div img = img / 255 elif 'gaussian_blur' in aug: - # Gaussian Blur - if aug == 'gaussian_blur_3': - kernel = 3 - elif aug == 'gaussian_blur_5': - kernel = 5 - else: - kernel = 2 * self.blur_scale * (rand_int % 3) + 1 - img = cv2.GaussianBlur(img, (kernel, kernel), 3) + img = cv2.GaussianBlur(img, (0,0), rand_val*1.5) elif 'motion_blur' in aug: # Motion blur - intensity = self.blur_scale * (rand_int % 3) + 1 - angle = (rand_int // 3) % 360 + intensity = self.blur_scale * rand_val * 3 + 1 + angle = random.randint(0,360) k = np.zeros((intensity, intensity), dtype=np.float32) k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32) k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0), (intensity, intensity)) k = k * (1.0 / np.sum(k)) img = cv2.filter2D(img, -1, k) - elif 'smooth_blur' in aug: - # Smooth blur - kernel = 2 * self.blur_scale * (rand_int % 3) + 1 - img = cv2.blur(img, ksize=(kernel, kernel)) elif 'block_noise' in aug: # Large distortion blocks in part of an img, such as is used to mask out a face. pass @@ -78,7 +91,7 @@ class ImageCorruptor: if 'lq_resampling4x' == aug: scale = 4 interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4] - mode = rand_int % len(interpolation_modes) + mode = random.randint(0,4) % len(interpolation_modes) # Downsample first, then upsample using the random mode. img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST) img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode) @@ -96,7 +109,7 @@ class ImageCorruptor: if 'noise-5' == aug: noise_intensity = 5 / 255.0 else: - noise_intensity = (rand_int % 4 + 2) / 255.0 # Between 1-4 + noise_intensity = (rand_val*4 + 2) / 255.0 img += np.random.randn(*img.shape) * noise_intensity elif 'jpeg' in aug: if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations: @@ -118,7 +131,7 @@ class ImageCorruptor: else: raise NotImplementedError("specified jpeg corruption doesn't exist") # JPEG compression - qf = (rand_int % range + lo) + qf = (int((1-rand_val)*range) + lo) # Use PIL to perform a mock compression to a data buffer, then swap back to cv2. img = (img * 255).astype(np.uint8) img = Image.fromarray(img) @@ -129,7 +142,7 @@ class ImageCorruptor: img = read_img("buffer", jpeg_img_bytes, rgb=True) elif 'saturation' in aug: # Lightening / saturation - saturation = float(rand_int % 10) * .03 + saturation = rand_val * .3 img = np.clip(img + saturation, a_max=1, a_min=0) elif 'none' not in aug: raise NotImplementedError("Augmentation doesn't exist") diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py index a0621eae..c970c7ef 100644 --- a/codes/data/image_folder_dataset.py +++ b/codes/data/image_folder_dataset.py @@ -112,14 +112,14 @@ class ImageFolderDataset: local_scale = local_scale // special_factor else: hs = [h.copy() for h in hs] - hs = self.corruptor.corrupt_images(hs) + hs, ent = self.corruptor.corrupt_images(hs, return_entropy=True) for hq in hs: h, w, _ = hq.shape ls.append(cv2.resize(hq, (h // local_scale, w // local_scale), interpolation=cv2.INTER_AREA)) # Corrupt the LQ image (only in eval mode) if not self.corrupt_before_downsize: - ls = self.corruptor.corrupt_images(ls) - return ls + ls, ent = self.corruptor.corrupt_images(ls, return_entropy=True) + return ls, ent def __len__(self): return self.len @@ -184,9 +184,10 @@ class ImageFolderDataset: out_dict['alt_hq'] = alt_hq if not self.skip_lq: - lqs = self.synthesize_lq(for_lq) + lqs, ent = self.synthesize_lq(for_lq) ls = lqs[0] out_dict['lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(ls, (2, 0, 1)))).float() + out_dict['corruption_entropy'] = torch.tensor(ent) if len(lqs) > 1: alt_lq = lqs[1] out_dict['alt_lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq, (2, 0, 1)))).float() @@ -215,25 +216,25 @@ class ImageFolderDataset: if __name__ == '__main__': opt = { 'name': 'amalgam', - 'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'], + 'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], 'weights': [1], 'target_size': 256, 'force_multiple': 1, 'scale': 2, 'corrupt_before_downsize': True, - 'fetch_alt_image': True, + 'fetch_alt_image': False, 'disable_flip': True, - 'fixed_corruptions': [ 'jpeg-broad' ], + 'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ], 'num_corrupts_per_image': 0, 'corruption_blur_scale': 0 } - ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2) + ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0) import os output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised' os.makedirs(output_path, exist_ok=True) for i, d in tqdm(enumerate(ds)): lq = d['lq'] - torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png') + #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png') if i >= 200000: break \ No newline at end of file diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index f58219b3..a9520bfe 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -660,12 +660,18 @@ class SuperResModel(UNetModel): Expects an extra kwarg `low_res` to condition on a low-resolution image. """ - def __init__(self, image_size, in_channels, *args, **kwargs): - super().__init__(image_size, in_channels * 2, *args, **kwargs) + def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs): + self.num_corruptions = 0 + super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs) - def forward(self, x, timesteps, low_res=None, **kwargs): - _, _, new_height, new_width = x.shape + def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs): + b, _, new_height, new_width = x.shape upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + if corruption_factor is not None: + corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width) + else: + corruption_factor = torch.zeros((b, self.num_corruptions, new_height, new_width), dtype=torch.float, device=x.device) + upsampled = torch.cat([upsampled, corruption_factor], dim=1) x = th.cat([x, upsampled], dim=1) res = super().forward(x, timesteps, **kwargs) return res