diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py index 47b9e103..f93bcb06 100644 --- a/codes/data/multiscale_dataset.py +++ b/codes/data/multiscale_dataset.py @@ -13,6 +13,24 @@ import torchvision.transforms.functional as F from data.image_corruptor import ImageCorruptor +# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping +# offset from center is chosen on a normal probability curve. +def get_square_image(image): + h, w, _ = image.shape + if h == w: + return image + offset = max(min(np.random.normal(scale=.3), 1.0), -1.0) + if h > w: + diff = h - w + center = diff // 2 + top = max(int(center + offset * (center - 2)), 0) + return image[top:top + w, :, :] + else: + diff = w - h + center = diff // 2 + left = max(int(center + offset * (center - 2)), 0) + return image[:, left:left + h, :] + class MultiScaleDataset(data.Dataset): def __init__(self, opt): super(MultiScaleDataset, self).__init__() @@ -25,23 +43,6 @@ class MultiScaleDataset(data.Dataset): self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']]) self.corruptor = ImageCorruptor(opt) - # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping - # offset from center is chosen on a normal probability curve. - def get_square_image(self, image): - h, w, _ = image.shape - if h == w: - return image - offset = max(min(np.random.normal(scale=.3), 1.0), -1.0) - if h > w: - diff = h - w - center = diff // 2 - top = int(center + offset * (center - 2)) - return image[top:top+w, :, :] - else: - diff = w - h - center = diff // 2 - left = int(center + offset * (center - 2)) - return image[:, left:left+h, :] def recursively_extract_patches(self, input_img, result_list, depth): if depth >= self.num_scales: @@ -62,7 +63,7 @@ class MultiScaleDataset(data.Dataset): loaded_img = util.read_img(None, full_path, None) img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0] img_full2 = util.augment([img_full1], True, True)[0] - img_full3 = self.get_square_image(img_full2) + img_full3 = get_square_image(img_full2) # This error crops up from time to time. I suspect an issue with util.read_img. if img_full3.shape[0] == 0 or img_full3.shape[1] == 0: print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))