diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py
index 62918b4a..55f8dd06 100644
--- a/codes/data/LQGT_dataset.py
+++ b/codes/data/LQGT_dataset.py
@@ -218,8 +218,9 @@ class LQGTDataset(data.Dataset):
         if img_GAN is not None:
             img_GAN = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GAN, (2, 0, 1)))).float()

-        lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
-        img_LQ += lq_noise
+        if 'lq_noise' in self.opt.keys():
+            lq_noise = torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
+            img_LQ += lq_noise

         if LQ_path is None:
             LQ_path = GT_path
diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py
new file mode 100644
index 00000000..b4b2e03e
--- /dev/null
+++ b/codes/data/full_image_dataset.py
@@ -0,0 +1,267 @@
+import random
+import numpy as np
+import cv2
+import torch
+import torch.utils.data as data
+import data.util as util
+from PIL import Image, ImageOps
+from io import BytesIO
+import torchvision.transforms.functional as F
+
+
+# Reads full-quality images and pulls tiles from them. Also extracts LR renderings of the full image with cues as to
+# where those tiles came from.
+class FullImageDataset(data.Dataset):
+    """
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
+    If only GT images are provided, generate LQ images on-the-fly.
+    """
+    def get_lq_path(self, i):
+        which_lq = random.randint(0, len(self.paths_LQ) - 1)
+        return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
+
+    def __init__(self, opt):
+        super(FullImageDataset, self).__init__()
+        self.opt = opt
+        self.data_type = 'img'
+        self.paths_LQ, self.paths_GT = None, None
+        self.sizes_LQ, self.sizes_GT = None, None
+        self.LQ_env, self.GT_env = None, None
+        self.force_multiple = self.opt['force_multiple'] if 'force_multiple' in self.opt.keys() else 1
+
+        self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
+        if 'dataroot_LQ' in opt.keys():
+            self.paths_LQ = []
+            if isinstance(opt['dataroot_LQ'], list):
+                # Multiple LQ data sources can be given, in case there are multiple ways of corrupting a source image
+                # and we want the model to learn them all.
+                for dr_lq in opt['dataroot_LQ']:
+                    lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, dr_lq)
+                    self.paths_LQ.append(lq_path)
+            else:
+                lq_path, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
+                self.paths_LQ.append(lq_path)
+
+        assert self.paths_GT, 'Error: GT path is empty.'
+        self.random_scale_list = [1]
+
+    def motion_blur(self, image, size, angle):
+        k = np.zeros((size, size), dtype=np.float32)
+        k[(size - 1) // 2, :] = np.ones(size, dtype=np.float32)
+        k = cv2.warpAffine(k, cv2.getRotationMatrix2D((size / 2 - 0.5, size / 2 - 0.5), angle, 1.0), (size, size))
+        k = k * (1.0 / np.sum(k))
+        return cv2.filter2D(image, -1, k)
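+
+    # A rough picture of the kernel above (illustrative only): for size=3 and angle=0,
+    # k is a horizontal line filter,
+    #     [[0,   0,   0  ],
+    #      [1/3, 1/3, 1/3],
+    #      [0,   0,   0  ]]
+    # which cv2.filter2D convolves over every channel, smearing each pixel along the line.
+    # Other angles rotate that line before filtering.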
+
+    # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
+    # offset from center is chosen on a normal probability curve.
+    def get_square_image(self, image):
+        h, w, _ = image.shape
+        if h == w:
+            return image
+        # Clamp the offset to [-1, 1] so the crop window cannot fall outside of the image.
+        offset = min(max(np.random.normal(scale=.3), -1.0), 1.0)
+        if h > w:
+            diff = h - w
+            center = diff // 2
+            top = int(center + offset * (center - 2))
+            return image[top:top+w, :, :]
+        else:
+            diff = w - h
+            center = diff // 2
+            left = int(center + offset * (center - 2))
+            return image[:, left:left+h, :]
+
+    def pick_along_range(self, sz, r, dev):
+        margin_sz = sz - r
+        margin_center = margin_sz // 2
+        return min(max(int(min(np.random.normal(scale=dev), 1.0) * margin_sz + margin_center), 0), margin_sz)
+
+    # - Randomly extracts a square from image and resizes it to opt['target_size'].
+    # - Fills a mask with zeros, then places 1's where the square was extracted from. Resizes this mask and the source
+    #   image to the target_size and returns those too.
+    # Notes:
+    # - When extracting a square, the size of the square is randomly distributed in [target_size, source_size] along a
+    #   half-normal distribution, biasing towards the target_size.
+    # - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
+    def pull_tile(self, image):
+        target_sz = self.opt['target_size']
+        h, w, _ = image.shape
+        possible_sizes_above_target = h - target_sz
+        square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.1)), 1.0))
+        # Pick the left,top coords to draw the patch from.
+        left = self.pick_along_range(w, square_size, .3)
+        top = self.pick_along_range(h, square_size, .3)
+
+        mask = np.zeros((h, w, 1), dtype=np.float32)
+        mask[top:top+square_size, left:left+square_size] = 1
+        patch = image[top:top+square_size, left:left+square_size, :]
+
+        patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
+        image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
+        mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
+
+        return patch, image, mask
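+
+    # Worked example of the sizing above (illustrative numbers): with target_size=128 and a
+    # 640px square source, possible_sizes_above_target is 512. |N(0, .1)| stays below .2
+    # roughly 95% of the time, so square_size usually lands in [128, ~230]; the full 640px
+    # is only reachable in the far tail of the distribution.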
+
+    def augment_tile(self, img_GT, img_LQ, strength=1):
+        scale = self.opt['scale']
+        GT_size = self.opt['target_size']
+
+        H, W, _ = img_GT.shape
+        assert H >= GT_size and W >= GT_size
+
+        LQ_size = GT_size // scale
+        img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
+        img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
+
+        if self.opt['use_blurring']:
+            # Pick randomly between gaussian (~40%), motion (~30%), or no blur (~30%).
+            blur_det = random.randint(0, 100)
+            blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
+            blur_magnitude = max(1, int(blur_magnitude * strength))
+            if blur_det < 40:
+                blur_sig = int(random.randrange(0, int(blur_magnitude)))
+                img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
+            elif blur_det < 70:
+                img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
+
+        return img_GT, img_LQ
+
+    # Converts img_LQ to PIL and performs JPG compression corruption and (optionally) grayscale conversion on the
+    # image, then returns it.
+    def pil_augment(self, img_LQ, strength=1):
+        img_LQ = (img_LQ * 255).astype(np.uint8)
+        img_LQ = Image.fromarray(img_LQ)
+        if self.opt['use_compression_artifacts'] and random.random() > .25:
+            # Cast to int so that random.randrange() receives integer bounds when strength is fractional.
+            sub_lo = int(90 * strength)
+            sub_hi = int(30 * strength)
+            qf = random.randrange(100 - sub_lo, 100 - sub_hi)
+            corruption_buffer = BytesIO()
+            img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimize=True)
+            corruption_buffer.seek(0)
+            img_LQ = Image.open(corruption_buffer)
+
+        if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
+            img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
+
+        return img_LQ
+
+    def __getitem__(self, index):
+        LQ_path = None
+        scale = self.opt['scale']
+        GT_size = self.opt['target_size']
+
+        # get full size image
+        GT_path = self.paths_GT[index % len(self.paths_GT)]
+        img_full = util.read_img(None, GT_path, None)
+        img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
+        img_full = self.get_square_image(img_full)
+        img_GT, gt_fullsize_ref, gt_mask = self.pull_tile(img_full)
+
+        # get LQ image
+        if self.paths_LQ:
+            LQ_path = self.get_lq_path(index)
+            img_lq_full = util.read_img(None, LQ_path, None)
+            img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
+            img_lq_full = self.get_square_image(img_lq_full)
+            img_LQ, lq_fullsize_ref, lq_mask = self.pull_tile(img_lq_full)
+        else:  # down-sampling on-the-fly
+            # randomly scale during training
+            if self.opt['phase'] == 'train':
+                random_scale = random.choice(self.random_scale_list)
+                H_s, W_s, _ = img_GT.shape
+
+                def _mod(n, random_scale, scale, thres):
+                    rlt = int(n * random_scale)
+                    rlt = (rlt // scale) * scale
+                    return thres if rlt < thres else rlt
+
+                H_s = _mod(H_s, random_scale, scale, GT_size)
+                W_s = _mod(W_s, random_scale, scale, GT_size)
+                img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
+                if img_GT.ndim == 2:
+                    img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
+
+            H, W, _ = img_GT.shape
+
+            # using matlab imresize
+            img_LQ = util.imresize_np(img_GT, 1 / scale, True)
+            if img_LQ.ndim == 2:
+                img_LQ = np.expand_dims(img_LQ, axis=2)
+            lq_fullsize_ref, lq_mask = gt_fullsize_ref, gt_mask
+
+        # Enforce force_multiple constraints: both LQ dimensions must be divisible by it.
+        h, w, _ = img_LQ.shape
+        if h % self.force_multiple != 0 or w % self.force_multiple != 0:
+            # Round each dimension down to the nearest multiple; note that cv2.resize() takes (width, height).
+            h, w = h - h % self.force_multiple, w - w % self.force_multiple
+            img_LQ = cv2.resize(img_LQ, (w, h))
+            img_GT = cv2.resize(img_GT, (w * scale, h * scale))
+
+        if self.opt['phase'] == 'train':
+            img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
+            gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
+        lq_mask = cv2.resize(lq_mask, (img_LQ.shape[1], img_LQ.shape[0]), interpolation=cv2.INTER_LINEAR)
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        if img_GT.shape[2] == 3:
+            img_GT = cv2.cvtColor(img_GT, cv2.COLOR_BGR2RGB)
+            img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_BGR2RGB)
+            lq_fullsize_ref = cv2.cvtColor(lq_fullsize_ref, cv2.COLOR_BGR2RGB)
+            gt_fullsize_ref = cv2.cvtColor(gt_fullsize_ref, cv2.COLOR_BGR2RGB)
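+
+        # Note on strengths (illustrative, given the integer bounds in pil_augment): strength=1
+        # draws a JPEG quality factor from [10, 70), while strength=.2 draws from [82, 94), so
+        # the full-size reference below is corrupted much more gently than the tile itself.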
+        # LQ needs to go to a PIL image to perform the compression-artifact transformation.
+        img_LQ = self.pil_augment(img_LQ)
+        lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
+
+        img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
+        gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
+        img_LQ = F.to_tensor(img_LQ)
+        lq_fullsize_ref = F.to_tensor(lq_fullsize_ref)
+        lq_mask = torch.from_numpy(np.ascontiguousarray(lq_mask)).unsqueeze(dim=0)
+        gt_mask = torch.from_numpy(np.ascontiguousarray(gt_mask)).unsqueeze(dim=0)
+
+        if 'lq_noise' in self.opt.keys():
+            # The tile and the full-size reference can differ in resolution, so draw noise separately
+            # for each rather than re-using one tensor across both.
+            img_LQ += torch.randn_like(img_LQ) * self.opt['lq_noise'] / 255
+            lq_fullsize_ref += torch.randn_like(lq_fullsize_ref) * self.opt['lq_noise'] / 255
+
+        # Concatenate the location masks onto the full-size references as a fourth channel.
+        lq_fullsize_ref = torch.cat([lq_fullsize_ref, lq_mask], dim=0)
+        gt_fullsize_ref = torch.cat([gt_fullsize_ref, gt_mask], dim=0)
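+
+        # Resulting shapes, as an illustration using the demo opt values below (target_size=128,
+        # scale=2): img_LQ is (3, 64, 64), img_GT is (3, 128, 128), lq_fullsize_ref is
+        # (4, 64, 64) and gt_fullsize_ref is (4, 128, 128): three color channels plus the
+        # location mask.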
+
+        if LQ_path is None:
+            LQ_path = GT_path
+        d = {'LQ': img_LQ, 'GT': img_GT, 'gt_fullsize_ref': gt_fullsize_ref, 'lq_fullsize_ref': lq_fullsize_ref,
+             'LQ_path': LQ_path, 'GT_path': GT_path}
+        return d
+
+    def __len__(self):
+        return len(self.paths_GT)
+
+
+if __name__ == '__main__':
+    opt = {
+        'name': 'amalgam',
+        'dataroot_GT': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\images'],
+        'dataroot_GT_weights': [1],
+        'use_flip': True,
+        'use_compression_artifacts': True,
+        'use_blurring': True,
+        'use_rot': True,
+        'lq_noise': 5,
+        'target_size': 128,
+        'scale': 2,
+        'phase': 'train'
+    }
+    ds = FullImageDataset(opt)
+    import os
+    import torchvision
+    os.makedirs("debug", exist_ok=True)
+    for i in range(1000):
+        o = ds[i]
+        for k, v in o.items():
+            if 'path' not in k:
+                if 'full' in k:
+                    masked = v[:3, :, :] * v[3]
+                    torchvision.utils.save_image(masked.unsqueeze(0), "debug/%i_%s_masked.png" % (i, k))
+                    v = v[:3, :, :]
+                torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
\ No newline at end of file
diff --git a/codes/data_scripts/extract_subimages.py b/codes/data_scripts/extract_subimages.py
index 175502f6..7d0cbd60 100644
--- a/codes/data_scripts/extract_subimages.py
+++ b/codes/data_scripts/extract_subimages.py
@@ -20,12 +20,12 @@ def main():
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.
     if mode == 'single':
-        opt['input_folder'] = 'F:\\4k6k\\datasets\\flickr2k\\Flickr2K_HR'
-        opt['save_folder'] = 'F:\\4k6k\\datasets\\flickr2k\\1024px'
-        opt['crop_sz'] = 1024  # the size of each sub-image
-        opt['step'] = 880  # step of the sliding crop window
-        opt['thres_sz'] = 240  # size threshold
-        opt['resize_final_img'] = 1
+        opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\images'
+        opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\square_context'
+        opt['crop_sz'] = 4096  # the size of each sub-image
+        opt['step'] = 4096  # step of the sliding crop window
+        opt['thres_sz'] = 256  # size threshold
+        opt['resize_final_img'] = .5
         opt['only_resize'] = False
         extract_single(opt, split_img)
     elif mode == 'pair':
@@ -93,6 +93,8 @@ def extract_single(opt, split_img=False):
     pool = Pool(opt['n_thread'])
     for path in img_list:
+        # If this fails, change it and the imwrite below to the right extension.
+        assert ".jpg" in path
         if split_img:
             pool.apply_async(worker, args=(path, opt, True, False), callback=update)
             pool.apply_async(worker, args=(path, opt, True, True), callback=update)
@@ -122,7 +124,6 @@ def worker(path, opt, split_mode=False, left_img=True):
     # Uncomment to filter any image that doesnt meet a threshold size.
     if min(h,w) < 1024:
         return
-    left = 0
     right = w
     if split_mode:
@@ -163,8 +164,6 @@ def worker(path, opt, split_mode=False, left_img=True):
         else:
             crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
         crop_img = np.ascontiguousarray(crop_img)
-        # If this fails, change it and the imwrite below to the write extension.
-        assert ".png" in img_name
         if 'resize_final_img' in opt.keys():
             # Resize too.
             resize_factor = opt['resize_final_img']
@@ -173,7 +172,7 @@ def worker(path, opt, split_mode=False, left_img=True):
             crop_img = cv2.resize(crop_img, dsize, interpolation = cv2.INTER_AREA)
         cv2.imwrite(
             osp.join(opt['save_folder'],
-                     img_name.replace('.png', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
+                     img_name.replace('.jpg', '_l{:05d}_s{:03d}.png'.format(left, index))), crop_img,
             [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
     return 'Processing {:s} ...'.format(img_name)