diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index 16f885d2..4c077eed 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -1,10 +1,29 @@
 import random
+from math import cos, pi
+
 import cv2
 import numpy as np
 from data.util import read_img
 from PIL import Image
 from io import BytesIO
 
+
+# Feeds a random uniform through a cosine distribution to slightly bias corruptions towards "uncorrupted".
+# Return is on [0,1] with a bias towards 0.
+def get_rand():
+    r = random.random()
+    return 1 - cos(r * pi / 2)
+
+# Get a rough visualization of the above distribution. (Y-axis is meaningless, just spreads data)
+'''
+if __name__ == '__main__':
+    import numpy as np
+    import matplotlib.pyplot as plt
+    data = np.asarray([get_rand() for _ in range(5000)])
+    plt.plot(data, np.random.uniform(size=(5000,)), 'x')
+    plt.show()
+'''
+
 # Performs image corruption on a list of images from a configurable set of corruption
 # options.
 class ImageCorruptor:
@@ -16,7 +35,7 @@ class ImageCorruptor:
             return
         self.random_corruptions = opt['random_corruptions'] if 'random_corruptions' in opt.keys() else []
 
-    def corrupt_images(self, imgs):
+    def corrupt_images(self, imgs, return_entropy=False):
         if self.num_corrupts == 0 and not self.fixed_corruptions:
-            return imgs
+            return (imgs, []) if return_entropy else imgs  # Callers unpack a tuple when return_entropy=True.
 
@@ -24,51 +43,45 @@ class ImageCorruptor:
             augmentations = []
         else:
             augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
 
-        # Source of entropy, which should be used across all images.
-        rand_int_f = random.randint(1, 999999)
-        rand_int_a = random.randint(1, 999999)
+        # Sources of entropy
         corrupted_imgs = []
+        entropy = []
         applied_augs = augmentations + self.fixed_corruptions
         for img in imgs:
             for aug in augmentations:
-                img = self.apply_corruption(img, aug, rand_int_a, applied_augs)
+                r = get_rand()
+                img = self.apply_corruption(img, aug, r, applied_augs)
             for aug in self.fixed_corruptions:
-                img = self.apply_corruption(img, aug, rand_int_f, applied_augs)
+                r = get_rand()
+                img = self.apply_corruption(img, aug, r, applied_augs)
+                entropy.append(r)
             corrupted_imgs.append(img)
 
-        return corrupted_imgs
+        if return_entropy:
+            return corrupted_imgs, entropy
+        else:
+            return corrupted_imgs
 
-    def apply_corruption(self, img, aug, rand_int, applied_augmentations):
+    def apply_corruption(self, img, aug, rand_val, applied_augmentations):
         if 'color_quantization' in aug:
             # Color quantization
-            quant_div = 2 ** ((rand_int % 3) + 2)
+            quant_div = 2 ** (int(rand_val * 10 / 3) + 2)
             img = img * 255
             img = (img // quant_div) * quant_div
             img = img / 255
         elif 'gaussian_blur' in aug:
-            # Gaussian Blur
-            if aug == 'gaussian_blur_3':
-                kernel = 3
-            elif aug == 'gaussian_blur_5':
-                kernel = 5
-            else:
-                kernel = 2 * self.blur_scale * (rand_int % 3) + 1
-            img = cv2.GaussianBlur(img, (kernel, kernel), 3)
+            img = cv2.GaussianBlur(img, (0,0), rand_val*1.5)
         elif 'motion_blur' in aug:
             # Motion blur
-            intensity = self.blur_scale * (rand_int % 3) + 1
-            angle = (rand_int // 3) % 360
+            intensity = int(self.blur_scale * rand_val * 3) + 1  # Kernel size is used as an array shape, so it must be an int.
+            angle = random.randint(0,360)
             k = np.zeros((intensity, intensity), dtype=np.float32)
             k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
             k = cv2.warpAffine(k, cv2.getRotationMatrix2D((intensity / 2 - 0.5, intensity / 2 - 0.5), angle, 1.0),
                                (intensity, intensity))
             k = k * (1.0 / np.sum(k))
             img = cv2.filter2D(img, -1, k)
-        elif 'smooth_blur' in aug:
-            # Smooth blur
-            kernel = 2 * self.blur_scale * (rand_int % 3) + 1
-            img = cv2.blur(img, ksize=(kernel, kernel))
         elif 'block_noise' in aug:
             # Large distortion blocks in part of an img, such as is used to mask out a face.
             pass
@@ -78,7 +91,7 @@ class ImageCorruptor:
             if 'lq_resampling4x' == aug:
                 scale = 4
             interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_LANCZOS4]
-            mode = rand_int % len(interpolation_modes)
+            mode = random.choice(interpolation_modes)  # Pick an actual cv2 flag, not a list index.
             # Downsample first, then upsample using the random mode.
             img = cv2.resize(img, dsize=(img.shape[1]//scale, img.shape[0]//scale), interpolation=cv2.INTER_NEAREST)
             img = cv2.resize(img, dsize=(img.shape[1]*scale, img.shape[0]*scale), interpolation=mode)
@@ -96,7 +109,7 @@ class ImageCorruptor:
             if 'noise-5' == aug:
                 noise_intensity = 5 / 255.0
             else:
-                noise_intensity = (rand_int % 4 + 2) / 255.0  # Between 1-4
+                noise_intensity = (rand_val*4 + 2) / 255.0
             img += np.random.randn(*img.shape) * noise_intensity
         elif 'jpeg' in aug:
             if 'noise' not in applied_augmentations and 'noise-5' not in applied_augmentations:
@@ -118,7 +131,7 @@ class ImageCorruptor:
             else:
                 raise NotImplementedError("specified jpeg corruption doesn't exist")
             # JPEG compression
-            qf = (rand_int % range + lo)
+            qf = (int((1-rand_val)*range) + lo)
             # Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
             img = (img * 255).astype(np.uint8)
             img = Image.fromarray(img)
@@ -129,7 +142,7 @@ class ImageCorruptor:
             img = read_img("buffer", jpeg_img_bytes, rgb=True)
         elif 'saturation' in aug:
             # Lightening / saturation
-            saturation = float(rand_int % 10) * .03
+            saturation = rand_val * .3
             img = np.clip(img + saturation, a_max=1, a_min=0)
         elif 'none' not in aug:
             raise NotImplementedError("Augmentation doesn't exist")
diff --git a/codes/data/image_folder_dataset.py b/codes/data/image_folder_dataset.py
index a0621eae..c970c7ef 100644
--- a/codes/data/image_folder_dataset.py
+++ b/codes/data/image_folder_dataset.py
@@ -112,14 +112,14 @@ class ImageFolderDataset:
                 local_scale = local_scale // special_factor
             else:
                 hs = [h.copy() for h in hs]
-            hs = self.corruptor.corrupt_images(hs)
+            hs, ent = self.corruptor.corrupt_images(hs, return_entropy=True)
         for hq in hs:
             h, w, _ = hq.shape
             ls.append(cv2.resize(hq, (h // local_scale, w // local_scale), interpolation=cv2.INTER_AREA))
         # Corrupt the LQ image (only in eval mode)
         if not self.corrupt_before_downsize:
-            ls = self.corruptor.corrupt_images(ls)
-        return ls
+            ls, ent = self.corruptor.corrupt_images(ls, return_entropy=True)
+        return ls, ent
 
     def __len__(self):
         return self.len
@@ -184,9 +184,10 @@ class ImageFolderDataset:
             out_dict['alt_hq'] = alt_hq
 
         if not self.skip_lq:
-            lqs = self.synthesize_lq(for_lq)
+            lqs, ent = self.synthesize_lq(for_lq)
             ls = lqs[0]
             out_dict['lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(ls, (2, 0, 1)))).float()
+            out_dict['corruption_entropy'] = torch.tensor(ent)
             if len(lqs) > 1:
                 alt_lq = lqs[1]
                 out_dict['alt_lq'] = torch.from_numpy(np.ascontiguousarray(np.transpose(alt_lq, (2, 0, 1)))).float()
@@ -215,25 +216,25 @@ class ImageFolderDataset:
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['E:\\4k6k\\datasets\\ns_images\\256_unsupervised'],
+        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         'weights': [1],
         'target_size': 256,
         'force_multiple': 1,
         'scale': 2,
         'corrupt_before_downsize': True,
-        'fetch_alt_image': True,
+        'fetch_alt_image': False,
         'disable_flip': True,
-        'fixed_corruptions': [ 'jpeg-broad' ],
+        'fixed_corruptions': [ 'jpeg-broad', 'gaussian_blur' ],
         'num_corrupts_per_image': 0,
         'corruption_blur_scale': 0
     }
 
-    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=2)
+    ds = DataLoader(ImageFolderDataset(opt), shuffle=True, num_workers=0)
     import os
     output_path = 'E:\\4k6k\\datasets\\ns_images\\128_unsupervised'
     os.makedirs(output_path, exist_ok=True)
     for i, d in tqdm(enumerate(ds)):
         lq = d['lq']
-        torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
+        #torchvision.utils.save_image(lq[:,:,16:-16,:], f'{output_path}\\{i+500000}.png')
         if i >= 200000:
             break
\ No newline at end of file
diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py
index f58219b3..a9520bfe 100644
--- a/codes/models/diffusion/unet_diffusion.py
+++ b/codes/models/diffusion/unet_diffusion.py
@@ -660,12 +660,18 @@ class SuperResModel(UNetModel):
     Expects an extra kwarg `low_res` to condition on a low-resolution image.
     """
 
-    def __init__(self, image_size, in_channels, *args, **kwargs):
-        super().__init__(image_size, in_channels * 2, *args, **kwargs)
+    def __init__(self, image_size, in_channels, num_corruptions=0, *args, **kwargs):
+        self.num_corruptions = num_corruptions  # Sizes the zero-filled fallback built in forward().
+        super().__init__(image_size, in_channels * 2 + num_corruptions, *args, **kwargs)
 
-    def forward(self, x, timesteps, low_res=None, **kwargs):
-        _, _, new_height, new_width = x.shape
+    def forward(self, x, timesteps, low_res=None, corruption_factor=None, **kwargs):
+        b, _, new_height, new_width = x.shape
         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        if corruption_factor is not None:
+            corruption_factor = corruption_factor.view(b, -1, 1, 1).repeat(1, 1, new_height, new_width)
+        else:
+            corruption_factor = th.zeros((b, self.num_corruptions, new_height, new_width), dtype=th.float, device=x.device)
+        upsampled = th.cat([upsampled, corruption_factor], dim=1)
         x = th.cat([x, upsampled], dim=1)
         res = super().forward(x, timesteps, **kwargs)
         return res