Fix bug with multiscale_dataset

This commit is contained in:
James Betker 2020-10-31 20:54:41 -06:00
parent eb7df63592
commit ea8c20c0e2

View File

@ -13,6 +13,24 @@ import torchvision.transforms.functional as F
from data.image_corruptor import ImageCorruptor
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
# offset from center is chosen on a normal probability curve.
def get_square_image(image):
h, w, _ = image.shape
if h == w:
return image
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
if h > w:
diff = h - w
center = diff // 2
top = max(int(center + offset * (center - 2)), 0)
return image[top:top + w, :, :]
else:
diff = w - h
center = diff // 2
left = max(int(center + offset * (center - 2)), 0)
return image[:, left:left + h, :]
class MultiScaleDataset(data.Dataset):
def __init__(self, opt):
super(MultiScaleDataset, self).__init__()
@ -25,23 +43,6 @@ class MultiScaleDataset(data.Dataset):
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
self.corruptor = ImageCorruptor(opt)
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
# offset from center is chosen on a normal probability curve.
def get_square_image(self, image):
h, w, _ = image.shape
if h == w:
return image
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
if h > w:
diff = h - w
center = diff // 2
top = int(center + offset * (center - 2))
return image[top:top+w, :, :]
else:
diff = w - h
center = diff // 2
left = int(center + offset * (center - 2))
return image[:, left:left+h, :]
def recursively_extract_patches(self, input_img, result_list, depth):
if depth >= self.num_scales:
@ -62,7 +63,7 @@ class MultiScaleDataset(data.Dataset):
loaded_img = util.read_img(None, full_path, None)
img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
img_full2 = util.augment([img_full1], True, True)[0]
img_full3 = self.get_square_image(img_full2)
img_full3 = get_square_image(img_full2)
# This error crops up from time to time. I suspect an issue with util.read_img.
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))