From 4f75cf0f027d53e362f49f4b9fdd1ddeb28fd745 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 18 Sep 2020 09:50:43 -0600 Subject: [PATCH] Revert "Full image dataset operates on lists" Going with an entirely new dataset instead.. This reverts commit 36ec32bf11583998a2ceffae287649971b0ea5ba. --- codes/data/full_image_dataset.py | 195 +++++++++++++++++-------------- 1 file changed, 110 insertions(+), 85 deletions(-) diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index b14943ff..5f9e9c92 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -54,24 +54,21 @@ class FullImageDataset(data.Dataset): # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # offset from center is chosen on a normal probability curve. - def get_square_image(self, images): - h, w, _ = images[0].shape + def get_square_image(self, image): + h, w, _ = image.shape if h == w: - return images + return image offset = max(min(np.random.normal(scale=.3), 1.0), -1.0) - res = [] - for image in images: - if h > w: - diff = h - w - center = diff // 2 - top = int(center + offset * (center - 2)) - res.append(image[top:top+w, :, :]) - else: - diff = w - h - center = diff // 2 - left = int(center + offset * (center - 2)) - res.append(image[:, left:left+h, :]) - return res + if h > w: + diff = h - w + center = diff // 2 + top = int(center + offset * (center - 2)) + return image[top:top+w, :, :] + else: + diff = w - h + center = diff // 2 + left = int(center + offset * (center - 2)) + return image[:, left:left+h, :] def pick_along_range(self, sz, r, dev): margin_sz = sz - r @@ -93,88 +90,120 @@ class FullImageDataset(data.Dataset): # - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a # half-normal distribution, biasing towards the target_size. # - A biased normal distribution is also used to bias the tile selection towards the center of the source image. - def pull_tile(self, images, lq=False): + def pull_tile(self, image, lq=False): if lq: target_sz = self.opt['min_tile_size'] // self.opt['scale'] else: target_sz = self.opt['min_tile_size'] - h, w, _ = images[0].shape + h, w, _ = image.shape possible_sizes_above_target = h - target_sz square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0)) # Pick the left,top coords to draw the patch from left = self.pick_along_range(w, square_size, .3) top = self.pick_along_range(w, square_size, .3) - patches, ims, masks, centers = [], [], [], [] - for image in images: - mask = np.zeros((h, w, 1), dtype=image.dtype) - mask[top:top+square_size, left:left+square_size] = 1 - patch = image[top:top+square_size, left:left+square_size, :] - center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long) + mask = np.zeros((h, w, 1), dtype=image.dtype) + mask[top:top+square_size, left:left+square_size] = 1 + patch = image[top:top+square_size, left:left+square_size, :] + center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long) - patches.append(cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) - ims.append(cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) - masks.append(cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) - centers.append(self.resize_point(center, (h, w), ims[-1].shape[:2])) + patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) + image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) + mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) + center = self.resize_point(center, (h, w), image.shape[:2]) - return patches, ims, masks, centers + return patch, image, mask, center - def augment_tile(self, imgs_GT, imgs_LQ, strength=1): + def augment_tile(self, img_GT, img_LQ, strength=1): scale = self.opt['scale'] GT_size = self.opt['target_size'] - H, W, _ = imgs_GT[0].shape + H, W, _ = img_GT.shape assert H >= GT_size and W >= GT_size - # Establish random variables. - blur_det = random.randint(0, 100) - blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] - blur_magnitude = max(1, int(blur_magnitude*strength)) - blur_sig = int(random.randrange(0, int(blur_magnitude))) - blur_direction = random.randint(0, 360) + LQ_size = GT_size // scale + img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) + img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) - lqs, gts = [], [] - for img_LQ, img_GT in zip(imgs_LQ, imgs_GT): - LQ_size = GT_size // scale - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - gts.append(cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)) + if self.opt['use_blurring']: + # Pick randomly between gaussian, motion, or no blur. + blur_det = random.randint(0, 100) + blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] + blur_magnitude = max(1, int(blur_magnitude*strength)) + if blur_det < 40: + blur_sig = int(random.randrange(0, int(blur_magnitude))) + img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig) + elif blur_det < 70: + img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360)) - if self.opt['use_blurring']: - # Pick randomly between gaussian, motion, or no blur. - if blur_det < 40: - lqs.append(cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)) - elif blur_det < 70: - lqs.append(self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), blur_direction)) - else: - lqs.append(img_LQ) - else: - lqs.append(img_LQ) - - return gts, lqs + return img_GT, img_LQ # Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it. - def pil_augment(self, imgs_LQ, strength=1): - # Compute random variables - do_jpg = random.random() - sub_lo = 90 * strength - sub_hi = 30 * strength - qf = random.randrange(100 - sub_lo, 100 - sub_hi) + def pil_augment(self, img_LQ, strength=1): + img_LQ = (img_LQ * 255).astype(np.uint8) + img_LQ = Image.fromarray(img_LQ) + if self.opt['use_compression_artifacts'] and random.random() > .25: + sub_lo = 90 * strength + sub_hi = 30 * strength + qf = random.randrange(100 - sub_lo, 100 - sub_hi) + corruption_buffer = BytesIO() + img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True) + corruption_buffer.seek(0) + img_LQ = Image.open(corruption_buffer) - ims_out = [] - for img_LQ in imgs_LQ: - img_LQ = (img_LQ * 255).astype(np.uint8) - img_LQ = Image.fromarray(img_LQ) - if self.opt['use_compression_artifacts'] and do_jpg > .25: - corruption_buffer = BytesIO() - img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True) - corruption_buffer.seek(0) - img_LQ = Image.open(corruption_buffer) + if 'grayscale' in self.opt.keys() and self.opt['grayscale']: + img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') - if 'grayscale' in self.opt.keys() and self.opt['grayscale']: - img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') - ims_out.append(img_LQ) + return img_LQ - return ims_out + def perform_random_hr_augment(self, image, aug_code=None, augmentations=1): + if aug_code is None: + aug_code = [random.randint(0, 10) for _ in range(augmentations)] + else: + assert augmentations == 1 + aug_code = [aug_code] + if 0 in aug_code: + # Color quantization + pass + elif 1 in aug_code: + # Gaussian Blur (point or motion) + blur_magnitude = 3 + blur_sig = int(random.randrange(0, int(blur_magnitude))) + image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig) + elif 2 in aug_code: + # Median Blur + image = cv2.medianBlur(image, 3) + elif 3 in aug_code: + # Motion blur + image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360)) + elif 4 in aug_code: + # Smooth blur + image = cv2.blur(image, ksize=3) + elif 5 in aug_code: + # Block noise + pass + elif 6 in aug_code: + # Bicubic LR->HR + pass + elif 7 in aug_code: + # Linear compression distortion + pass + elif 8 in aug_code: + # Interlacing distortion + pass + elif 9 in aug_code: + # Chromatic aberration + pass + elif 10 in aug_code: + # Noise + pass + elif 11 in aug_code: + # JPEG compression + pass + elif 12 in aug_code: + # Lightening / saturation + pass + return image def __getitem__(self, index): scale = self.opt['scale'] @@ -186,9 +215,8 @@ class FullImageDataset(data.Dataset): img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0] if self.opt['phase'] == 'train': img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0] - img_full = self.get_square_image([img_full])[0] - imgs_GT, gt_fullsize_refs, gt_masks, gt_centers = self.pull_tile([img_full]) - img_GT, gt_fullsize_ref, gt_mask, gt_center = imgs_GT[0], gt_fullsize_refs[0], gt_masks[0], gt_centers[0] + img_full = self.get_square_image(img_full) + img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full) else: img_GT, gt_fullsize_ref = img_full, img_full gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype) @@ -200,9 +228,8 @@ class FullImageDataset(data.Dataset): LQ_path = self.get_lq_path(index) img_lq_full = util.read_img(None, LQ_path, None) img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] - img_lq_full = self.get_square_image([img_lq_full])[0] - imgs_LQ, lq_fullsize_refs, lq_masks, lq_centers = self.pull_tile([img_lq_full], lq=True) - img_LQ, lq_fullsize_ref, lq_mask, lq_center = imgs_LQ[0], lq_fullsize_refs[0], lq_masks[0], lq_centers[0] + img_lq_full = self.get_square_image(img_lq_full) + img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True) else: # down-sampling on-the-fly # randomly scale during training if self.opt['phase'] == 'train': @@ -247,10 +274,8 @@ class FullImageDataset(data.Dataset): gt_fullsize_ref = gt_fullsize_ref[:h, :w, :] if self.opt['phase'] == 'train': - imgs_GT, imgs_LQ = self.augment_tile([img_GT], [img_LQ]) - img_GT, img_LQ = imgs_GT[0], imgs_LQ[0] - gt_fullsize_refs, lq_fullsize_refs = self.augment_tile([gt_fullsize_ref], [lq_fullsize_ref], strength=.2) - gt_fullsize_ref, lq_fullsize_ref = gt_fullsize_refs[0], lq_fullsize_refs[0] + img_GT, img_LQ = self.augment_tile(img_GT, img_LQ) + gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2) # Scale masks. lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) @@ -269,8 +294,8 @@ class FullImageDataset(data.Dataset): # LQ needs to go to a PIL image to perform the compression-artifact transformation. if self.opt['phase'] == 'train': - img_LQ = self.pil_augment([img_LQ])[0] - lq_fullsize_ref = self.pil_augment([lq_fullsize_ref], strength=.2)[0] + img_LQ = self.pil_augment(img_LQ) + lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2) img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()