Full image dataset operates on lists
This commit is contained in:
parent
3cb2a9a9d3
commit
36ec32bf11
|
@ -54,21 +54,24 @@ class FullImageDataset(data.Dataset):
|
|||
|
||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||
# offset from center is chosen on a normal probability curve.
|
||||
def get_square_image(self, image):
|
||||
h, w, _ = image.shape
|
||||
def get_square_image(self, images):
|
||||
h, w, _ = images[0].shape
|
||||
if h == w:
|
||||
return image
|
||||
return images
|
||||
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
||||
res = []
|
||||
for image in images:
|
||||
if h > w:
|
||||
diff = h - w
|
||||
center = diff // 2
|
||||
top = int(center + offset * (center - 2))
|
||||
return image[top:top+w, :, :]
|
||||
res.append(image[top:top+w, :, :])
|
||||
else:
|
||||
diff = w - h
|
||||
center = diff // 2
|
||||
left = int(center + offset * (center - 2))
|
||||
return image[:, left:left+h, :]
|
||||
res.append(image[:, left:left+h, :])
|
||||
return res
|
||||
|
||||
def pick_along_range(self, sz, r, dev):
|
||||
margin_sz = sz - r
|
||||
|
@ -90,62 +93,78 @@ class FullImageDataset(data.Dataset):
|
|||
# - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a
|
||||
# half-normal distribution, biasing towards the target_size.
|
||||
# - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
|
||||
def pull_tile(self, image, lq=False):
|
||||
def pull_tile(self, images, lq=False):
|
||||
if lq:
|
||||
target_sz = self.opt['min_tile_size'] // self.opt['scale']
|
||||
else:
|
||||
target_sz = self.opt['min_tile_size']
|
||||
h, w, _ = image.shape
|
||||
h, w, _ = images[0].shape
|
||||
possible_sizes_above_target = h - target_sz
|
||||
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0))
|
||||
# Pick the left,top coords to draw the patch from
|
||||
left = self.pick_along_range(w, square_size, .3)
|
||||
top = self.pick_along_range(w, square_size, .3)
|
||||
|
||||
patches, ims, masks, centers = [], [], [], []
|
||||
for image in images:
|
||||
mask = np.zeros((h, w, 1), dtype=image.dtype)
|
||||
mask[top:top+square_size, left:left+square_size] = 1
|
||||
patch = image[top:top+square_size, left:left+square_size, :]
|
||||
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
|
||||
|
||||
patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||
image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||
mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||
center = self.resize_point(center, (h, w), image.shape[:2])
|
||||
patches.append(cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
|
||||
ims.append(cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
|
||||
masks.append(cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
|
||||
centers.append(self.resize_point(center, (h, w), ims[-1].shape[:2]))
|
||||
|
||||
return patch, image, mask, center
|
||||
return patches, ims, masks, centers
|
||||
|
||||
def augment_tile(self, img_GT, img_LQ, strength=1):
|
||||
def augment_tile(self, imgs_GT, imgs_LQ, strength=1):
|
||||
scale = self.opt['scale']
|
||||
GT_size = self.opt['target_size']
|
||||
|
||||
H, W, _ = img_GT.shape
|
||||
H, W, _ = imgs_GT[0].shape
|
||||
assert H >= GT_size and W >= GT_size
|
||||
|
||||
LQ_size = GT_size // scale
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if self.opt['use_blurring']:
|
||||
# Pick randomly between gaussian, motion, or no blur.
|
||||
# Establish random variables.
|
||||
blur_det = random.randint(0, 100)
|
||||
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
|
||||
blur_magnitude = max(1, int(blur_magnitude*strength))
|
||||
if blur_det < 40:
|
||||
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
||||
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
elif blur_det < 70:
|
||||
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
|
||||
blur_direction = random.randint(0, 360)
|
||||
|
||||
return img_GT, img_LQ
|
||||
lqs, gts = [], []
|
||||
for img_LQ, img_GT in zip(imgs_LQ, imgs_GT):
|
||||
LQ_size = GT_size // scale
|
||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||
gts.append(cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR))
|
||||
|
||||
if self.opt['use_blurring']:
|
||||
# Pick randomly between gaussian, motion, or no blur.
|
||||
if blur_det < 40:
|
||||
lqs.append(cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig))
|
||||
elif blur_det < 70:
|
||||
lqs.append(self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), blur_direction))
|
||||
else:
|
||||
lqs.append(img_LQ)
|
||||
else:
|
||||
lqs.append(img_LQ)
|
||||
|
||||
return gts, lqs
|
||||
|
||||
# Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
|
||||
def pil_augment(self, img_LQ, strength=1):
|
||||
img_LQ = (img_LQ * 255).astype(np.uint8)
|
||||
img_LQ = Image.fromarray(img_LQ)
|
||||
if self.opt['use_compression_artifacts'] and random.random() > .25:
|
||||
def pil_augment(self, imgs_LQ, strength=1):
|
||||
# Compute random variables
|
||||
do_jpg = random.random()
|
||||
sub_lo = 90 * strength
|
||||
sub_hi = 30 * strength
|
||||
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
|
||||
|
||||
ims_out = []
|
||||
for img_LQ in imgs_LQ:
|
||||
img_LQ = (img_LQ * 255).astype(np.uint8)
|
||||
img_LQ = Image.fromarray(img_LQ)
|
||||
if self.opt['use_compression_artifacts'] and do_jpg > .25:
|
||||
corruption_buffer = BytesIO()
|
||||
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
||||
corruption_buffer.seek(0)
|
||||
|
@ -153,57 +172,9 @@ class FullImageDataset(data.Dataset):
|
|||
|
||||
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
|
||||
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
|
||||
ims_out.append(img_LQ)
|
||||
|
||||
return img_LQ
|
||||
|
||||
def perform_random_hr_augment(self, image, aug_code=None, augmentations=1):
|
||||
if aug_code is None:
|
||||
aug_code = [random.randint(0, 10) for _ in range(augmentations)]
|
||||
else:
|
||||
assert augmentations == 1
|
||||
aug_code = [aug_code]
|
||||
if 0 in aug_code:
|
||||
# Color quantization
|
||||
pass
|
||||
elif 1 in aug_code:
|
||||
# Gaussian Blur (point or motion)
|
||||
blur_magnitude = 3
|
||||
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
||||
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
|
||||
elif 2 in aug_code:
|
||||
# Median Blur
|
||||
image = cv2.medianBlur(image, 3)
|
||||
elif 3 in aug_code:
|
||||
# Motion blur
|
||||
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
|
||||
elif 4 in aug_code:
|
||||
# Smooth blur
|
||||
image = cv2.blur(image, ksize=3)
|
||||
elif 5 in aug_code:
|
||||
# Block noise
|
||||
pass
|
||||
elif 6 in aug_code:
|
||||
# Bicubic LR->HR
|
||||
pass
|
||||
elif 7 in aug_code:
|
||||
# Linear compression distortion
|
||||
pass
|
||||
elif 8 in aug_code:
|
||||
# Interlacing distortion
|
||||
pass
|
||||
elif 9 in aug_code:
|
||||
# Chromatic aberration
|
||||
pass
|
||||
elif 10 in aug_code:
|
||||
# Noise
|
||||
pass
|
||||
elif 11 in aug_code:
|
||||
# JPEG compression
|
||||
pass
|
||||
elif 12 in aug_code:
|
||||
# Lightening / saturation
|
||||
pass
|
||||
return image
|
||||
return ims_out
|
||||
|
||||
def __getitem__(self, index):
|
||||
scale = self.opt['scale']
|
||||
|
@ -215,8 +186,9 @@ class FullImageDataset(data.Dataset):
|
|||
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
|
||||
if self.opt['phase'] == 'train':
|
||||
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||
img_full = self.get_square_image(img_full)
|
||||
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
|
||||
img_full = self.get_square_image([img_full])[0]
|
||||
imgs_GT, gt_fullsize_refs, gt_masks, gt_centers = self.pull_tile([img_full])
|
||||
img_GT, gt_fullsize_ref, gt_mask, gt_center = imgs_GT[0], gt_fullsize_refs[0], gt_masks[0], gt_centers[0]
|
||||
else:
|
||||
img_GT, gt_fullsize_ref = img_full, img_full
|
||||
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
|
||||
|
@ -228,8 +200,9 @@ class FullImageDataset(data.Dataset):
|
|||
LQ_path = self.get_lq_path(index)
|
||||
img_lq_full = util.read_img(None, LQ_path, None)
|
||||
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||
img_lq_full = self.get_square_image(img_lq_full)
|
||||
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
|
||||
img_lq_full = self.get_square_image([img_lq_full])[0]
|
||||
imgs_LQ, lq_fullsize_refs, lq_masks, lq_centers = self.pull_tile([img_lq_full], lq=True)
|
||||
img_LQ, lq_fullsize_ref, lq_mask, lq_center = imgs_LQ[0], lq_fullsize_refs[0], lq_masks[0], lq_centers[0]
|
||||
else: # down-sampling on-the-fly
|
||||
# randomly scale during training
|
||||
if self.opt['phase'] == 'train':
|
||||
|
@ -274,8 +247,10 @@ class FullImageDataset(data.Dataset):
|
|||
gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]
|
||||
|
||||
if self.opt['phase'] == 'train':
|
||||
img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
|
||||
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
|
||||
imgs_GT, imgs_LQ = self.augment_tile([img_GT], [img_LQ])
|
||||
img_GT, img_LQ = imgs_GT[0], imgs_LQ[0]
|
||||
gt_fullsize_refs, lq_fullsize_refs = self.augment_tile([gt_fullsize_ref], [lq_fullsize_ref], strength=.2)
|
||||
gt_fullsize_ref, lq_fullsize_ref = gt_fullsize_refs[0], lq_fullsize_refs[0]
|
||||
|
||||
# Scale masks.
|
||||
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||
|
@ -294,8 +269,8 @@ class FullImageDataset(data.Dataset):
|
|||
|
||||
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||
if self.opt['phase'] == 'train':
|
||||
img_LQ = self.pil_augment(img_LQ)
|
||||
lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
|
||||
img_LQ = self.pil_augment([img_LQ])[0]
|
||||
lq_fullsize_ref = self.pil_augment([lq_fullsize_ref], strength=.2)[0]
|
||||
|
||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
|
||||
|
|
Loading…
Reference in New Issue
Block a user