Full image dataset operates on lists

This commit is contained in:
James Betker 2020-09-18 09:50:26 -06:00
parent 3cb2a9a9d3
commit 36ec32bf11

View File

@ -54,21 +54,24 @@ class FullImageDataset(data.Dataset):
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
# offset from center is chosen on a normal probability curve.
def get_square_image(self, image):
h, w, _ = image.shape
def get_square_image(self, images):
h, w, _ = images[0].shape
if h == w:
return image
return images
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
if h > w:
diff = h - w
center = diff // 2
top = int(center + offset * (center - 2))
return image[top:top+w, :, :]
else:
diff = w - h
center = diff // 2
left = int(center + offset * (center - 2))
return image[:, left:left+h, :]
res = []
for image in images:
if h > w:
diff = h - w
center = diff // 2
top = int(center + offset * (center - 2))
res.append(image[top:top+w, :, :])
else:
diff = w - h
center = diff // 2
left = int(center + offset * (center - 2))
res.append(image[:, left:left+h, :])
return res
def pick_along_range(self, sz, r, dev):
margin_sz = sz - r
@ -90,120 +93,88 @@ class FullImageDataset(data.Dataset):
# - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a
# half-normal distribution, biasing towards the target_size.
# - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
def pull_tile(self, image, lq=False):
def pull_tile(self, images, lq=False):
if lq:
target_sz = self.opt['min_tile_size'] // self.opt['scale']
else:
target_sz = self.opt['min_tile_size']
h, w, _ = image.shape
h, w, _ = images[0].shape
possible_sizes_above_target = h - target_sz
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0))
# Pick the left,top coords to draw the patch from
left = self.pick_along_range(w, square_size, .3)
top = self.pick_along_range(w, square_size, .3)
mask = np.zeros((h, w, 1), dtype=image.dtype)
mask[top:top+square_size, left:left+square_size] = 1
patch = image[top:top+square_size, left:left+square_size, :]
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
patches, ims, masks, centers = [], [], [], []
for image in images:
mask = np.zeros((h, w, 1), dtype=image.dtype)
mask[top:top+square_size, left:left+square_size] = 1
patch = image[top:top+square_size, left:left+square_size, :]
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
center = self.resize_point(center, (h, w), image.shape[:2])
patches.append(cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
ims.append(cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
masks.append(cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
centers.append(self.resize_point(center, (h, w), ims[-1].shape[:2]))
return patch, image, mask, center
return patches, ims, masks, centers
def augment_tile(self, img_GT, img_LQ, strength=1):
def augment_tile(self, imgs_GT, imgs_LQ, strength=1):
scale = self.opt['scale']
GT_size = self.opt['target_size']
H, W, _ = img_GT.shape
H, W, _ = imgs_GT[0].shape
assert H >= GT_size and W >= GT_size
LQ_size = GT_size // scale
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
# Establish random variables.
blur_det = random.randint(0, 100)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
blur_magnitude = max(1, int(blur_magnitude*strength))
blur_sig = int(random.randrange(0, int(blur_magnitude)))
blur_direction = random.randint(0, 360)
if self.opt['use_blurring']:
# Pick randomly between gaussian, motion, or no blur.
blur_det = random.randint(0, 100)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
blur_magnitude = max(1, int(blur_magnitude*strength))
if blur_det < 40:
blur_sig = int(random.randrange(0, int(blur_magnitude)))
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
elif blur_det < 70:
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
lqs, gts = [], []
for img_LQ, img_GT in zip(imgs_LQ, imgs_GT):
LQ_size = GT_size // scale
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
gts.append(cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR))
return img_GT, img_LQ
if self.opt['use_blurring']:
# Pick randomly between gaussian, motion, or no blur.
if blur_det < 40:
lqs.append(cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig))
elif blur_det < 70:
lqs.append(self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), blur_direction))
else:
lqs.append(img_LQ)
else:
lqs.append(img_LQ)
return gts, lqs
# Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
def pil_augment(self, img_LQ, strength=1):
img_LQ = (img_LQ * 255).astype(np.uint8)
img_LQ = Image.fromarray(img_LQ)
if self.opt['use_compression_artifacts'] and random.random() > .25:
sub_lo = 90 * strength
sub_hi = 30 * strength
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
def pil_augment(self, imgs_LQ, strength=1):
# Compute random variables
do_jpg = random.random()
sub_lo = 90 * strength
sub_hi = 30 * strength
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
ims_out = []
for img_LQ in imgs_LQ:
img_LQ = (img_LQ * 255).astype(np.uint8)
img_LQ = Image.fromarray(img_LQ)
if self.opt['use_compression_artifacts'] and do_jpg > .25:
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
return img_LQ
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
ims_out.append(img_LQ)
def perform_random_hr_augment(self, image, aug_code=None, augmentations=1):
if aug_code is None:
aug_code = [random.randint(0, 10) for _ in range(augmentations)]
else:
assert augmentations == 1
aug_code = [aug_code]
if 0 in aug_code:
# Color quantization
pass
elif 1 in aug_code:
# Gaussian Blur (point or motion)
blur_magnitude = 3
blur_sig = int(random.randrange(0, int(blur_magnitude)))
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
elif 2 in aug_code:
# Median Blur
image = cv2.medianBlur(image, 3)
elif 3 in aug_code:
# Motion blur
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
elif 4 in aug_code:
# Smooth blur
image = cv2.blur(image, ksize=3)
elif 5 in aug_code:
# Block noise
pass
elif 6 in aug_code:
# Bicubic LR->HR
pass
elif 7 in aug_code:
# Linear compression distortion
pass
elif 8 in aug_code:
# Interlacing distortion
pass
elif 9 in aug_code:
# Chromatic aberration
pass
elif 10 in aug_code:
# Noise
pass
elif 11 in aug_code:
# JPEG compression
pass
elif 12 in aug_code:
# Lightening / saturation
pass
return image
return ims_out
def __getitem__(self, index):
scale = self.opt['scale']
@ -215,8 +186,9 @@ class FullImageDataset(data.Dataset):
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
if self.opt['phase'] == 'train':
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_full = self.get_square_image(img_full)
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
img_full = self.get_square_image([img_full])[0]
imgs_GT, gt_fullsize_refs, gt_masks, gt_centers = self.pull_tile([img_full])
img_GT, gt_fullsize_ref, gt_mask, gt_center = imgs_GT[0], gt_fullsize_refs[0], gt_masks[0], gt_centers[0]
else:
img_GT, gt_fullsize_ref = img_full, img_full
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
@ -228,8 +200,9 @@ class FullImageDataset(data.Dataset):
LQ_path = self.get_lq_path(index)
img_lq_full = util.read_img(None, LQ_path, None)
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = self.get_square_image(img_lq_full)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
img_lq_full = self.get_square_image([img_lq_full])[0]
imgs_LQ, lq_fullsize_refs, lq_masks, lq_centers = self.pull_tile([img_lq_full], lq=True)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = imgs_LQ[0], lq_fullsize_refs[0], lq_masks[0], lq_centers[0]
else: # down-sampling on-the-fly
# randomly scale during training
if self.opt['phase'] == 'train':
@ -274,8 +247,10 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]
if self.opt['phase'] == 'train':
img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
imgs_GT, imgs_LQ = self.augment_tile([img_GT], [img_LQ])
img_GT, img_LQ = imgs_GT[0], imgs_LQ[0]
gt_fullsize_refs, lq_fullsize_refs = self.augment_tile([gt_fullsize_ref], [lq_fullsize_ref], strength=.2)
gt_fullsize_ref, lq_fullsize_ref = gt_fullsize_refs[0], lq_fullsize_refs[0]
# Scale masks.
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
@ -294,8 +269,8 @@ class FullImageDataset(data.Dataset):
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
if self.opt['phase'] == 'train':
img_LQ = self.pil_augment(img_LQ)
lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
img_LQ = self.pil_augment([img_LQ])[0]
lq_fullsize_ref = self.pil_augment([lq_fullsize_ref], strength=.2)[0]
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()