Revert "Full image dataset operates on lists"

Going with an entirely new dataset implementation instead.

This reverts commit 36ec32bf11.
This commit is contained in:
James Betker 2020-09-18 09:50:43 -06:00
parent 36ec32bf11
commit 4f75cf0f02

View File

@ -54,24 +54,21 @@ class FullImageDataset(data.Dataset):
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
# offset from center is chosen on a normal probability curve. # offset from center is chosen on a normal probability curve.
def get_square_image(self, images): def get_square_image(self, image):
h, w, _ = images[0].shape h, w, _ = image.shape
if h == w: if h == w:
return images return image
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0) offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
res = [] if h > w:
for image in images: diff = h - w
if h > w: center = diff // 2
diff = h - w top = int(center + offset * (center - 2))
center = diff // 2 return image[top:top+w, :, :]
top = int(center + offset * (center - 2)) else:
res.append(image[top:top+w, :, :]) diff = w - h
else: center = diff // 2
diff = w - h left = int(center + offset * (center - 2))
center = diff // 2 return image[:, left:left+h, :]
left = int(center + offset * (center - 2))
res.append(image[:, left:left+h, :])
return res
def pick_along_range(self, sz, r, dev): def pick_along_range(self, sz, r, dev):
margin_sz = sz - r margin_sz = sz - r
@ -93,88 +90,120 @@ class FullImageDataset(data.Dataset):
# - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a # - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a
# half-normal distribution, biasing towards the target_size. # half-normal distribution, biasing towards the target_size.
# - A biased normal distribution is also used to bias the tile selection towards the center of the source image. # - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
def pull_tile(self, images, lq=False): def pull_tile(self, image, lq=False):
if lq: if lq:
target_sz = self.opt['min_tile_size'] // self.opt['scale'] target_sz = self.opt['min_tile_size'] // self.opt['scale']
else: else:
target_sz = self.opt['min_tile_size'] target_sz = self.opt['min_tile_size']
h, w, _ = images[0].shape h, w, _ = image.shape
possible_sizes_above_target = h - target_sz possible_sizes_above_target = h - target_sz
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0)) square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.17)), 1.0))
# Pick the left,top coords to draw the patch from # Pick the left,top coords to draw the patch from
left = self.pick_along_range(w, square_size, .3) left = self.pick_along_range(w, square_size, .3)
top = self.pick_along_range(w, square_size, .3) top = self.pick_along_range(w, square_size, .3)
patches, ims, masks, centers = [], [], [], [] mask = np.zeros((h, w, 1), dtype=image.dtype)
for image in images: mask[top:top+square_size, left:left+square_size] = 1
mask = np.zeros((h, w, 1), dtype=image.dtype) patch = image[top:top+square_size, left:left+square_size, :]
mask[top:top+square_size, left:left+square_size] = 1 center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
patch = image[top:top+square_size, left:left+square_size, :]
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
patches.append(cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) patch = cv2.resize(patch, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
ims.append(cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) image = cv2.resize(image, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
masks.append(cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)) mask = cv2.resize(mask, (target_sz, target_sz), interpolation=cv2.INTER_LINEAR)
centers.append(self.resize_point(center, (h, w), ims[-1].shape[:2])) center = self.resize_point(center, (h, w), image.shape[:2])
return patches, ims, masks, centers return patch, image, mask, center
def augment_tile(self, imgs_GT, imgs_LQ, strength=1): def augment_tile(self, img_GT, img_LQ, strength=1):
scale = self.opt['scale'] scale = self.opt['scale']
GT_size = self.opt['target_size'] GT_size = self.opt['target_size']
H, W, _ = imgs_GT[0].shape H, W, _ = img_GT.shape
assert H >= GT_size and W >= GT_size assert H >= GT_size and W >= GT_size
# Establish random variables. LQ_size = GT_size // scale
blur_det = random.randint(0, 100) img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR)
blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude'] img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
blur_magnitude = max(1, int(blur_magnitude*strength))
blur_sig = int(random.randrange(0, int(blur_magnitude)))
blur_direction = random.randint(0, 360)
lqs, gts = [], [] if self.opt['use_blurring']:
for img_LQ, img_GT in zip(imgs_LQ, imgs_GT): # Pick randomly between gaussian, motion, or no blur.
LQ_size = GT_size // scale blur_det = random.randint(0, 100)
img_LQ = cv2.resize(img_LQ, (LQ_size, LQ_size), interpolation=cv2.INTER_LINEAR) blur_magnitude = 3 if 'blur_magnitude' not in self.opt.keys() else self.opt['blur_magnitude']
gts.append(cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)) blur_magnitude = max(1, int(blur_magnitude*strength))
if blur_det < 40:
blur_sig = int(random.randrange(0, int(blur_magnitude)))
img_LQ = cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig)
elif blur_det < 70:
img_LQ = self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), random.randint(0, 360))
if self.opt['use_blurring']: return img_GT, img_LQ
# Pick randomly between gaussian, motion, or no blur.
if blur_det < 40:
lqs.append(cv2.GaussianBlur(img_LQ, (blur_magnitude, blur_magnitude), blur_sig))
elif blur_det < 70:
lqs.append(self.motion_blur(img_LQ, random.randrange(1, int(blur_magnitude) * 3), blur_direction))
else:
lqs.append(img_LQ)
else:
lqs.append(img_LQ)
return gts, lqs
# Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it. # Converts img_LQ to PIL and performs JPG compression corruptions and grayscale on the image, then returns it.
def pil_augment(self, imgs_LQ, strength=1): def pil_augment(self, img_LQ, strength=1):
# Compute random variables img_LQ = (img_LQ * 255).astype(np.uint8)
do_jpg = random.random() img_LQ = Image.fromarray(img_LQ)
sub_lo = 90 * strength if self.opt['use_compression_artifacts'] and random.random() > .25:
sub_hi = 30 * strength sub_lo = 90 * strength
qf = random.randrange(100 - sub_lo, 100 - sub_hi) sub_hi = 30 * strength
qf = random.randrange(100 - sub_lo, 100 - sub_hi)
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
ims_out = [] if 'grayscale' in self.opt.keys() and self.opt['grayscale']:
for img_LQ in imgs_LQ: img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
img_LQ = (img_LQ * 255).astype(np.uint8)
img_LQ = Image.fromarray(img_LQ)
if self.opt['use_compression_artifacts'] and do_jpg > .25:
corruption_buffer = BytesIO()
img_LQ.save(corruption_buffer, "JPEG", quality=qf, optimice=True)
corruption_buffer.seek(0)
img_LQ = Image.open(corruption_buffer)
if 'grayscale' in self.opt.keys() and self.opt['grayscale']: return img_LQ
img_LQ = ImageOps.grayscale(img_LQ).convert('RGB')
ims_out.append(img_LQ)
return ims_out def perform_random_hr_augment(self, image, aug_code=None, augmentations=1):
if aug_code is None:
aug_code = [random.randint(0, 10) for _ in range(augmentations)]
else:
assert augmentations == 1
aug_code = [aug_code]
if 0 in aug_code:
# Color quantization
pass
elif 1 in aug_code:
# Gaussian Blur (point or motion)
blur_magnitude = 3
blur_sig = int(random.randrange(0, int(blur_magnitude)))
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
elif 2 in aug_code:
# Median Blur
image = cv2.medianBlur(image, 3)
elif 3 in aug_code:
# Motion blur
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
elif 4 in aug_code:
# Smooth blur
image = cv2.blur(image, ksize=3)
elif 5 in aug_code:
# Block noise
pass
elif 6 in aug_code:
# Bicubic LR->HR
pass
elif 7 in aug_code:
# Linear compression distortion
pass
elif 8 in aug_code:
# Interlacing distortion
pass
elif 9 in aug_code:
# Chromatic aberration
pass
elif 10 in aug_code:
# Noise
pass
elif 11 in aug_code:
# JPEG compression
pass
elif 12 in aug_code:
# Lightening / saturation
pass
return image
def __getitem__(self, index): def __getitem__(self, index):
scale = self.opt['scale'] scale = self.opt['scale']
@ -186,9 +215,8 @@ class FullImageDataset(data.Dataset):
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0] img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0]
if self.opt['phase'] == 'train': if self.opt['phase'] == 'train':
img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0] img_full = util.augment([img_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_full = self.get_square_image([img_full])[0] img_full = self.get_square_image(img_full)
imgs_GT, gt_fullsize_refs, gt_masks, gt_centers = self.pull_tile([img_full]) img_GT, gt_fullsize_ref, gt_mask, gt_center = self.pull_tile(img_full)
img_GT, gt_fullsize_ref, gt_mask, gt_center = imgs_GT[0], gt_fullsize_refs[0], gt_masks[0], gt_centers[0]
else: else:
img_GT, gt_fullsize_ref = img_full, img_full img_GT, gt_fullsize_ref = img_full, img_full
gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype) gt_mask = np.ones(img_full.shape[:2], dtype=gt_fullsize_ref.dtype)
@ -200,9 +228,8 @@ class FullImageDataset(data.Dataset):
LQ_path = self.get_lq_path(index) LQ_path = self.get_lq_path(index)
img_lq_full = util.read_img(None, LQ_path, None) img_lq_full = util.read_img(None, LQ_path, None)
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = self.get_square_image([img_lq_full])[0] img_lq_full = self.get_square_image(img_lq_full)
imgs_LQ, lq_fullsize_refs, lq_masks, lq_centers = self.pull_tile([img_lq_full], lq=True) img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = imgs_LQ[0], lq_fullsize_refs[0], lq_masks[0], lq_centers[0]
else: # down-sampling on-the-fly else: # down-sampling on-the-fly
# randomly scale during training # randomly scale during training
if self.opt['phase'] == 'train': if self.opt['phase'] == 'train':
@ -247,10 +274,8 @@ class FullImageDataset(data.Dataset):
gt_fullsize_ref = gt_fullsize_ref[:h, :w, :] gt_fullsize_ref = gt_fullsize_ref[:h, :w, :]
if self.opt['phase'] == 'train': if self.opt['phase'] == 'train':
imgs_GT, imgs_LQ = self.augment_tile([img_GT], [img_LQ]) img_GT, img_LQ = self.augment_tile(img_GT, img_LQ)
img_GT, img_LQ = imgs_GT[0], imgs_LQ[0] gt_fullsize_ref, lq_fullsize_ref = self.augment_tile(gt_fullsize_ref, lq_fullsize_ref, strength=.2)
gt_fullsize_refs, lq_fullsize_refs = self.augment_tile([gt_fullsize_ref], [lq_fullsize_ref], strength=.2)
gt_fullsize_ref, lq_fullsize_ref = gt_fullsize_refs[0], lq_fullsize_refs[0]
# Scale masks. # Scale masks.
lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR) lq_mask = cv2.resize(lq_mask, (lq_fullsize_ref.shape[1], lq_fullsize_ref.shape[0]), interpolation=cv2.INTER_LINEAR)
@ -269,8 +294,8 @@ class FullImageDataset(data.Dataset):
# LQ needs to go to a PIL image to perform the compression-artifact transformation. # LQ needs to go to a PIL image to perform the compression-artifact transformation.
if self.opt['phase'] == 'train': if self.opt['phase'] == 'train':
img_LQ = self.pil_augment([img_LQ])[0] img_LQ = self.pil_augment(img_LQ)
lq_fullsize_ref = self.pil_augment([lq_fullsize_ref], strength=.2)[0] lq_fullsize_ref = self.pil_augment(lq_fullsize_ref, strength=.2)
img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float() gt_fullsize_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(gt_fullsize_ref, (2, 0, 1)))).float()