Revert "Full image dataset operates on lists"
Going with an entirely new dataset instead..
This reverts commit 36ec32bf11
.
This commit is contained in:
parent
36ec32bf11
commit
4f75cf0f02
|
@ -54,24 +54,21 @@ class FullImageDataset(data.Dataset):
|
||||||
|
|
||||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||||
# offset from center is chosen on a normal probability curve.
|
# offset from center is chosen on a normal probability curve.
|
||||||
def get_square_image(self, images):
|
def get_square_image(self, image):
|
||||||
h, w, _ = images[0].shape
|
h, w, _ = image.shape
|
||||||
if h == w:
|
if h == w:
|
||||||
return images
|
return image
|
||||||
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
||||||
res = []
|
if h > w:
|
||||||
for image in images:
|
diff = h - w
|
||||||
if h > w:
|
center = diff // 2
|
||||||
diff = h - w
|
top = int(center + offset * (center - 2))
|
||||||
center = diff // 2
|
return image[top:top+w, :, :]
|
||||||
top = int(center + offset * (center - 2))
|
else:
|
||||||
res.append(image[top:top+w, :, :])
|
diff = w - h
|
||||||
else:
|
center = diff // 2
|
||||||
diff = w - h
|
left = int(center + offset * (center - 2))
|
||||||
center = diff // 2
|
return image[:, left:left+h, :]
|
||||||
left = int(center + offset * (center - 2))
|
|
||||||
res.append(image[:, left:left+h, :])
|
|
||||||
return res
|
|
||||||
|
|
||||||
def pick_along_range(self, sz, r, dev):
|
def pick_along_range(self, sz, r, dev):
|
||||||
margin_sz = sz - r
|
margin_sz = sz - r
|
||||||
|
@ -93,88 +90,120 @@ class FullImageDataset(data.Dataset):
|
||||||
# - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a
|
# - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a
|
||||||
# half-normal distribution, biasing towards the target_size.
|
# half-normal distribution, biasing towards the target_size.
|
||||||
# - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
|
# - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
|
||||||
def pull_tile(self, images, lq=False):
|
def pull_tile(self, image, lq=False):
|
||||||
if lq:
|
if lq:
|
||||||
target_sz = self.opt['min_tile_size'] // self.opt['scale']
|
target_sz = self.opt['min_tile_size'] // self.opt['scale']
|
||||||
else:
|
else:
|
||||||
target_sz = self.opt['min_tile_size']
|
target_sz = self.opt['min_tile_size']
|
||||||
h, w, _ = images[0].shape
|
h, w, _ = image.shape
|
||||||
possible_sizes_above_target = h - target_sz
|
possible_sizes_above_target = h - target_sz
|
||||||
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0))
|
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0))
|
||||||
# Pick the left,top coords to draw the patch from
|
# Pick the left,top coords to draw the patch from
|
||||||
left = self.pick_along_range(w, square_size, .3)
|
left = self.pick_along_range(w, square_size, .3)
|
||||||
top = self.pick_along_range(w, square_size, .3)
|
top = self.pick_along_range(w, square_size, .3)
|
||||||
|
|
||||||
patches, ims, masks, centers = [], [], [], []
|
mask = np.zeros((h, w, 1), dtype=image.dtype)
|
||||||
for image in images:
|
mask[top:top+square_size, left:left+square_size] = 1
|
||||||
mask = np.zeros((h, w, 1), dtype=image.dtype)
|
patch = image[top:top+square_size, left:left+square_size, :]
|
||||||
mask[top:top+square_size, left:left+square_size] = 1
|
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
|
||||||
patch = image[top:top+square_size, left:left+square_size, :]
|
|
||||||
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
|
|
||||||
|
|
||||||
patches.append(cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
|
patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||||
ims.append(cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
|
image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||||
masks.append(cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR))
|
mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
|
||||||
centers.append(self.resize_point(center, (h, w), ims[-1].shape[:2]))
|
center = self.resize_point(center, (h, w), image.shape[:2])
|
||||||
|
|
||||||
return patches, ims, masks, centers
|
return patch, image, mask, center
|
||||||
|
|
||||||
def augment_tile(self, imgs_GT, imgs_LQ, strength=1):
|
def augment_tile(self, img_GT, img_LQ, strength=1):
|
||||||
scale = self.opt['scale']
|
scale = self.opt['scale']
|
||||||
GT_size = self.opt['target_size']
|
GT_size = self.opt['target_size']
|
||||||
|
|
||||||
H, W, _ = imgs_GT[0].shape
|
H, W, _ = img_GT.shape
|
||||||
assert H >= GT_size and W >= GT_size
|
assert H >= GT_size and W >= GT_size
|
||||||
|
|
||||||
# Establish random variables.
|
LQ_size = GT_size // scale
|
||||||
blur_det = random.randint(0, 100)
|
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
||||||
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
|
img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
|
||||||
blur_magnitude = max(1, int(blur_magnitude*strength))
|
|
||||||
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
|
||||||
blur_direction = random.randint(0, 360)
|
|
||||||
|
|
||||||
lqs, gts = [], []
|
if self.opt['use_blurring']:
|
||||||
for img_LQ, img_GT in zip(imgs_LQ, imgs_GT):
|
# Pick randomly between gaussian, motion, or no blur.
|
||||||
LQ_size = GT_size // scale
|
blur_det = random.randint(0, 100)
|
||||||
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
|
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
|
||||||
gts.append(cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR))
|
blur_magnitude = max(1, int(blur_magnitude*strength))
|
||||||
|
if blur_det < 40:
|
||||||
|
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
||||||
|
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
|
||||||
|
elif blur_det < 70:
|
||||||
|
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
|
||||||
|
|
||||||
if self.opt['use_blurring']:
|
return img_GT, img_LQ
|
||||||
# Pick randomly between gaussian, motion, or no blur.
|
|
||||||
if blur_det < 40:
|
|
||||||
lqs.append(cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig))
|
|
||||||
elif blur_det < 70:
|
|
||||||
lqs.append(self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), blur_direction))
|
|
||||||
else:
|
|
||||||
lqs.append(img_LQ)
|
|
||||||
else:
|
|
||||||
lqs.append(img_LQ)
|
|
||||||
|
|
||||||
return gts, lqs
|
|
||||||
|
|
||||||
# Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
|
# Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
|
||||||
def pil_augment(self, imgs_LQ, strength=1):
|
def pil_augment(self, img_LQ, strength=1):
|
||||||
# Compute random variables
|
img_LQ = (img_LQ * 255).astype(np.uint8)
|
||||||
do_jpg = random.random()
|
img_LQ = Image.fromarray(img_LQ)
|
||||||
sub_lo = 90 * strength
|
if self.opt['use_compression_artifacts'] and random.random() > .25:
|
||||||
sub_hi = 30 * strength
|
sub_lo = 90 * strength
|
||||||
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
|
sub_hi = 30 * strength
|
||||||
|
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
|
||||||
|
corruption_buffer = BytesIO()
|
||||||
|
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
||||||
|
corruption_buffer.seek(0)
|
||||||
|
img_LQ = Image.open(corruption_buffer)
|
||||||
|
|
||||||
ims_out = []
|
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
|
||||||
for img_LQ in imgs_LQ:
|
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
|
||||||
img_LQ = (img_LQ * 255).astype(np.uint8)
|
|
||||||
img_LQ = Image.fromarray(img_LQ)
|
|
||||||
if self.opt['use_compression_artifacts'] and do_jpg > .25:
|
|
||||||
corruption_buffer = BytesIO()
|
|
||||||
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
|
|
||||||
corruption_buffer.seek(0)
|
|
||||||
img_LQ = Image.open(corruption_buffer)
|
|
||||||
|
|
||||||
if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
|
return img_LQ
|
||||||
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
|
|
||||||
ims_out.append(img_LQ)
|
|
||||||
|
|
||||||
return ims_out
|
def perform_random_hr_augment(self, image, aug_code=None, augmentations=1):
|
||||||
|
if aug_code is None:
|
||||||
|
aug_code = [random.randint(0, 10) for _ in range(augmentations)]
|
||||||
|
else:
|
||||||
|
assert augmentations == 1
|
||||||
|
aug_code = [aug_code]
|
||||||
|
if 0 in aug_code:
|
||||||
|
# Color quantization
|
||||||
|
pass
|
||||||
|
elif 1 in aug_code:
|
||||||
|
# Gaussian Blur (point or motion)
|
||||||
|
blur_magnitude = 3
|
||||||
|
blur_sig = int(random.randrange(0, int(blur_magnitude)))
|
||||||
|
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
|
||||||
|
elif 2 in aug_code:
|
||||||
|
# Median Blur
|
||||||
|
image = cv2.medianBlur(image, 3)
|
||||||
|
elif 3 in aug_code:
|
||||||
|
# Motion blur
|
||||||
|
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
|
||||||
|
elif 4 in aug_code:
|
||||||
|
# Smooth blur
|
||||||
|
image = cv2.blur(image, ksize=3)
|
||||||
|
elif 5 in aug_code:
|
||||||
|
# Block noise
|
||||||
|
pass
|
||||||
|
elif 6 in aug_code:
|
||||||
|
# Bicubic LR->HR
|
||||||
|
pass
|
||||||
|
elif 7 in aug_code:
|
||||||
|
# Linear compression distortion
|
||||||
|
pass
|
||||||
|
elif 8 in aug_code:
|
||||||
|
# Interlacing distortion
|
||||||
|
pass
|
||||||
|
elif 9 in aug_code:
|
||||||
|
# Chromatic aberration
|
||||||
|
pass
|
||||||
|
elif 10 in aug_code:
|
||||||
|
# Noise
|
||||||
|
pass
|
||||||
|
elif 11 in aug_code:
|
||||||
|
# JPEG compression
|
||||||
|
pass
|
||||||
|
elif 12 in aug_code:
|
||||||
|
# Lightening / saturation
|
||||||
|
pass
|
||||||
|
return image
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
scale = self.opt['scale']
|
scale = self.opt['scale']
|
||||||
|
@ -186,9 +215,8 @@ class FullImageDataset(data.Dataset):
|
||||||
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
|
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
|
||||||
if self.opt['phase'] == 'train':
|
if self.opt['phase'] == 'train':
|
||||||
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||||
img_full = self.get_square_image([img_full])[0]
|
img_full = self.get_square_image(img_full)
|
||||||
imgs_GT, gt_fullsize_refs, gt_masks, gt_centers = self.pull_tile([img_full])
|
img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
|
||||||
img_GT, gt_fullsize_ref, gt_mask, gt_center = imgs_GT[0], gt_fullsize_refs[0], gt_masks[0], gt_centers[0]
|
|
||||||
else:
|
else:
|
||||||
img_GT, gt_fullsize_ref = img_full, img_full
|
img_GT, gt_fullsize_ref = img_full, img_full
|
||||||
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
|
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
|
||||||
|
@ -200,9 +228,8 @@ class FullImageDataset(data.Dataset):
|
||||||
LQ_path = self.get_lq_path(index)
|
LQ_path = self.get_lq_path(index)
|
||||||
img_lq_full = util.read_img(None, LQ_path, None)
|
img_lq_full = util.read_img(None, LQ_path, None)
|
||||||
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
|
||||||
img_lq_full = self.get_square_image([img_lq_full])[0]
|
img_lq_full = self.get_square_image(img_lq_full)
|
||||||
imgs_LQ, lq_fullsize_refs, lq_masks, lq_centers = self.pull_tile([img_lq_full], lq=True)
|
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
|
||||||
img_LQ, lq_fullsize_ref, lq_mask, lq_center = imgs_LQ[0], lq_fullsize_refs[0], lq_masks[0], lq_centers[0]
|
|
||||||
else: # down-sampling on-the-fly
|
else: # down-sampling on-the-fly
|
||||||
# randomly scale during training
|
# randomly scale during training
|
||||||
if self.opt['phase'] == 'train':
|
if self.opt['phase'] == 'train':
|
||||||
|
@ -247,10 +274,8 @@ class FullImageDataset(data.Dataset):
|
||||||
gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]
|
gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]
|
||||||
|
|
||||||
if self.opt['phase'] == 'train':
|
if self.opt['phase'] == 'train':
|
||||||
imgs_GT, imgs_LQ = self.augment_tile([img_GT], [img_LQ])
|
img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
|
||||||
img_GT, img_LQ = imgs_GT[0], imgs_LQ[0]
|
gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
|
||||||
gt_fullsize_refs, lq_fullsize_refs = self.augment_tile([gt_fullsize_ref], [lq_fullsize_ref], strength=.2)
|
|
||||||
gt_fullsize_ref, lq_fullsize_ref = gt_fullsize_refs[0], lq_fullsize_refs[0]
|
|
||||||
|
|
||||||
# Scale masks.
|
# Scale masks.
|
||||||
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
|
||||||
|
@ -269,8 +294,8 @@ class FullImageDataset(data.Dataset):
|
||||||
|
|
||||||
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
# LQ needs to go to a PIL image to perform the compression-artifact transformation.
|
||||||
if self.opt['phase'] == 'train':
|
if self.opt['phase'] == 'train':
|
||||||
img_LQ = self.pil_augment([img_LQ])[0]
|
img_LQ = self.pil_augment(img_LQ)
|
||||||
lq_fullsize_ref = self.pil_augment([lq_fullsize_ref], strength=.2)[0]
|
lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
|
||||||
|
|
||||||
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
|
||||||
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
|
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user