From 36ec32bf11583998a2ceffae287649971b0ea5ba Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 18 Sep 2020 09:50:26 -0600 Subject: [PATCH] Full image dataset operates on lists --- codes/data/full_image_dataset.py | 195 ++++++++++++++----------------- 1 file changed, 85 insertions(+), 110 deletions(-) diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py index 5f9e9c92..b14943ff 100644 --- a/codes/data/full_image_dataset.py +++ b/codes/data/full_image_dataset.py @@ -54,21 +54,24 @@ class FullImageDataset(data.Dataset): # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # offset from center is chosen on a normal probability curve. - def get_square_image(self, image): - h, w, _ = image.shape + def get_square_image(self, images): + h, w, _ = images[0].shape if h == w: - return image + return images offset = max(min(np.random.normal(scale=.3), 1.0), -1.0) - if h > w: - diff = h - w - center = diff // 2 - top = int(center + offset * (center - 2)) - return image[top:top+w, :, :] - else: - diff = w - h - center = diff // 2 - left = int(center + offset * (center - 2)) - return image[:, left:left+h, :] + res = [] + for image in images: + if h > w: + diff = h - w + center = diff // 2 + top = int(center + offset * (center - 2)) + res.append(image[top:top+w, :, :]) + else: + diff = w - h + center = diff // 2 + left = int(center + offset * (center - 2)) + res.append(image[:, left:left+h, :]) + return res def pick_along_range(self, sz, r, dev): margin_sz = sz - r @@ -90,120 +93,88 @@ class FullImageDataset(data.Dataset): # - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a # half-normal distribution, biasing towards the target_size. # - A biased normal distribution is also used to bias the tile selection towards the center of the source image. - def pull_tile(self, image, lq=False): + def pull_tile(self, images, lq=False): if lq: target_sz = self.opt['min_tile_size'] // self.opt['scale'] else: target_sz = self.opt['min_tile_size'] - h, w, _ = image.shape + h, w, _ = images[0].shape possible_sizes_above_target = h - target_sz square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0)) # Pick the left,top coords to draw the patch from left = self.pick_along_range(w, square_size, .3) top = self.pick_along_range(w, square_size, .3) - mask = np.zeros((h, w, 1), dtype=image.dtype) - mask[top:top+square_size, left:left+square_size] = 1 - patch = image[top:top+square_size, left:left+square_size, :] - center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long) + patches, ims, masks, centers = [], [], [], [] + for image in images: + mask = np.zeros((h, w, 1), dtype=image.dtype) + mask[top:top+square_size, left:left+square_size] = 1 + patch = image[top:top+square_size, left:left+square_size, :] + center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long) - patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) - image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) - mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR) - center = self.resize_point(center, (h, w), image.shape[:2]) + patches.append(cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) + ims.append(cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) + masks.append(cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) + centers.append(self.resize_point(center, (h, w), ims[-1].shape[:2])) - return patch, image, mask, center + return patches, ims, masks, centers - def augment_tile(self, img_GT, img_LQ, strength=1): + def augment_tile(self, imgs_GT, imgs_LQ, strength=1): scale = self.opt['scale'] GT_size = self.opt['target_size'] - H, W, _ = img_GT.shape + H, W, _ = imgs_GT[0].shape assert H >= GT_size and W >= GT_size - LQ_size = GT_size // scale - img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) - img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) + # Establish random variables. + blur_det = random.randint(0, 100) + blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] + blur_magnitude = max(1, int(blur_magnitude*strength)) + blur_sig = int(random.randrange(0, int(blur_magnitude))) + blur_direction = random.randint(0, 360) - if self.opt['use_blurring']: - # Pick randomly between gaussian, motion, or no blur. - blur_det = random.randint(0, 100) - blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] - blur_magnitude = max(1, int(blur_magnitude*strength)) - if blur_det < 40: - blur_sig = int(random.randrange(0, int(blur_magnitude))) - img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig) - elif blur_det < 70: - img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360)) + lqs, gts = [], [] + for img_LQ, img_GT in zip(imgs_LQ, imgs_GT): + LQ_size = GT_size // scale + img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) + gts.append(cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)) - return img_GT, img_LQ + if self.opt['use_blurring']: + # Pick randomly between gaussian, motion, or no blur. + if blur_det < 40: + lqs.append(cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)) + elif blur_det < 70: + lqs.append(self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), blur_direction)) + else: + lqs.append(img_LQ) + else: + lqs.append(img_LQ) + + return gts, lqs # Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it. - def pil_augment(self, img_LQ, strength=1): - img_LQ = (img_LQ * 255).astype(np.uint8) - img_LQ = Image.fromarray(img_LQ) - if self.opt['use_compression_artifacts'] and random.random() > .25: - sub_lo = 90 * strength - sub_hi = 30 * strength - qf = random.randrange(100 - sub_lo, 100 - sub_hi) - corruption_buffer = BytesIO() - img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True) - corruption_buffer.seek(0) - img_LQ = Image.open(corruption_buffer) + def pil_augment(self, imgs_LQ, strength=1): + # Compute random variables + do_jpg = random.random() + sub_lo = 90 * strength + sub_hi = 30 * strength + qf = random.randrange(100 - sub_lo, 100 - sub_hi) - if 'grayscale' in self.opt.keys() and self.opt['grayscale']: - img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') + ims_out = [] + for img_LQ in imgs_LQ: + img_LQ = (img_LQ * 255).astype(np.uint8) + img_LQ = Image.fromarray(img_LQ) + if self.opt['use_compression_artifacts'] and do_jpg > .25: + corruption_buffer = BytesIO() + img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True) + corruption_buffer.seek(0) + img_LQ = Image.open(corruption_buffer) - return img_LQ + if 'grayscale' in self.opt.keys() and self.opt['grayscale']: + img_LQ = ImageOps.grayscale(img_LQ).convert('RGB') + ims_out.append(img_LQ) - def perform_random_hr_augment(self, image, aug_code=None, augmentations=1): - if aug_code is None: - aug_code = [random.randint(0, 10) for _ in range(augmentations)] - else: - assert augmentations == 1 - aug_code = [aug_code] - if 0 in aug_code: - # Color quantization - pass - elif 1 in aug_code: - # Gaussian Blur (point or motion) - blur_magnitude = 3 - blur_sig = int(random.randrange(0, int(blur_magnitude))) - image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig) - elif 2 in aug_code: - # Median Blur - image = cv2.medianBlur(image, 3) - elif 3 in aug_code: - # Motion blur - image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360)) - elif 4 in aug_code: - # Smooth blur - image = cv2.blur(image, ksize=3) - elif 5 in aug_code: - # Block noise - pass - elif 6 in aug_code: - # Bicubic LR->HR - pass - elif 7 in aug_code: - # Linear compression distortion - pass - elif 8 in aug_code: - # Interlacing distortion - pass - elif 9 in aug_code: - # Chromatic aberration - pass - elif 10 in aug_code: - # Noise - pass - elif 11 in aug_code: - # JPEG compression - pass - elif 12 in aug_code: - # Lightening / saturation - pass - return image + return ims_out def __getitem__(self, index): scale = self.opt['scale'] @@ -215,8 +186,9 @@ class FullImageDataset(data.Dataset): img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0] if self.opt['phase'] == 'train': img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0] - img_full = self.get_square_image(img_full) - img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full) + img_full = self.get_square_image([img_full])[0] + imgs_GT, gt_fullsize_refs, gt_masks, gt_centers = self.pull_tile([img_full]) + img_GT, gt_fullsize_ref, gt_mask, gt_center = imgs_GT[0], gt_fullsize_refs[0], gt_masks[0], gt_centers[0] else: img_GT, gt_fullsize_ref = img_full, img_full gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype) @@ -228,8 +200,9 @@ class FullImageDataset(data.Dataset): LQ_path = self.get_lq_path(index) img_lq_full = util.read_img(None, LQ_path, None) img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] - img_lq_full = self.get_square_image(img_lq_full) - img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True) + img_lq_full = self.get_square_image([img_lq_full])[0] + imgs_LQ, lq_fullsize_refs, lq_masks, lq_centers = self.pull_tile([img_lq_full], lq=True) + img_LQ, lq_fullsize_ref, lq_mask, lq_center = imgs_LQ[0], lq_fullsize_refs[0], lq_masks[0], lq_centers[0] else: # down-sampling on-the-fly # randomly scale during training if self.opt['phase'] == 'train': @@ -274,8 +247,10 @@ class FullImageDataset(data.Dataset): gt_fullsize_ref = gt_fullsize_ref[:h, :w, :] if self.opt['phase'] == 'train': - img_GT, img_LQ = self.augment_tile(img_GT, img_LQ) - gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2) + imgs_GT, imgs_LQ = self.augment_tile([img_GT], [img_LQ]) + img_GT, img_LQ = imgs_GT[0], imgs_LQ[0] + gt_fullsize_refs, lq_fullsize_refs = self.augment_tile([gt_fullsize_ref], [lq_fullsize_ref], strength=.2) + gt_fullsize_ref, lq_fullsize_ref = gt_fullsize_refs[0], lq_fullsize_refs[0] # Scale masks. lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) @@ -294,8 +269,8 @@ class FullImageDataset(data.Dataset): # LQ needs to go to a PIL image to perform the compression-artifact transformation. if self.opt['phase'] == 'train': - img_LQ = self.pil_augment(img_LQ) - lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2) + img_LQ = self.pil_augment([img_LQ])[0] + lq_fullsize_ref = self.pil_augment([lq_fullsize_ref], strength=.2)[0] img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()