Use lower LQ image size when it is being fed in

This commit is contained in:
James Betker 2020-09-06 17:26:32 -06:00
parent b1238d29cb
commit a5c2388368

View File

@ -90,8 +90,11 @@ class FullImageDataset(data.Dataset):
# - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a # - When extracting a square, the size of the square is randomly distributed [target_size, source_size] along a
# half-normal distribution, biasing towards the target_size. # half-normal distribution, biasing towards the target_size.
# - A biased normal distribution is also used to bias the tile selection towards the center of the source image. # - A biased normal distribution is also used to bias the tile selection towards the center of the source image.
def pull_tile(self, image): def pull_tile(self, image, lq=False):
target_sz = self.opt['min_tile_size'] if lq:
target_sz = self.opt['min_tile_size'] // self.opt['scale']
else:
target_sz = self.opt['min_tile_size']
h, w, _ = image.shape h, w, _ = image.shape
possible_sizes_above_target = h - target_sz possible_sizes_above_target = h - target_sz
square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.1)), 1.0)) square_size = int(target_sz + possible_sizes_above_target * min(np.abs(np.random.normal(scale=.1)), 1.0))
@ -153,6 +156,55 @@ class FullImageDataset(data.Dataset):
return img_LQ return img_LQ
def perform_random_hr_augment(self, image, aug_code=None, augmentations=1):
if aug_code is None:
aug_code = [random.randint(0, 10) for _ in range(augmentations)]
else:
assert augmentations == 1
aug_code = [aug_code]
if 0 in aug_code:
# Color quantization
pass
elif 1 in aug_code:
# Gaussian Blur (point or motion)
blur_magnitude = 3
blur_sig = int(random.randrange(0, int(blur_magnitude)))
image = cv2.GaussianBlur(image, (blur_magnitude, blur_magnitude), blur_sig)
elif 2 in aug_code:
# Median Blur
image = cv2.medianBlur(image, 3)
elif 3 in aug_code:
# Motion blur
image = self.motion_blur(image, random.randrange(1, 9), random.randint(0, 360))
elif 4 in aug_code:
# Smooth blur
image = cv2.blur(image, ksize=3)
elif 5 in aug_code:
# Block noise
pass
elif 6 in aug_code:
# Bicubic LR->HR
pass
elif 7 in aug_code:
# Linear compression distortion
pass
elif 8 in aug_code:
# Interlacing distortion
pass
elif 9 in aug_code:
# Chromatic aberration
pass
elif 10 in aug_code:
# Noise
pass
elif 11 in aug_code:
# JPEG compression
pass
elif 12 in aug_code:
# Lightening / saturation
pass
return image
def __getitem__(self, index): def __getitem__(self, index):
scale = self.opt['scale'] scale = self.opt['scale']
@ -177,7 +229,7 @@ class FullImageDataset(data.Dataset):
img_lq_full = util.read_img(None, LQ_path, None) img_lq_full = util.read_img(None, LQ_path, None)
img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0] img_lq_full = util.augment([img_lq_full], self.opt['use_flip'], self.opt['use_rot'])[0]
img_lq_full = self.get_square_image(img_lq_full) img_lq_full = self.get_square_image(img_lq_full)
img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full) img_LQ, lq_fullsize_ref, lq_mask, lq_center = self.pull_tile(img_lq_full, lq=True)
else: # down-sampling on-the-fly else: # down-sampling on-the-fly
# randomly scale during training # randomly scale during training
if self.opt['phase'] == 'train': if self.opt['phase'] == 'train':