Fix bug with multiscale_dataset
This commit is contained in:
parent
eb7df63592
commit
ea8c20c0e2
|
@ -13,6 +13,24 @@ import torchvision.transforms.functional as F
|
||||||
from data.image_corruptor import ImageCorruptor
|
from data.image_corruptor import ImageCorruptor
|
||||||
|
|
||||||
|
|
||||||
|
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||||
|
# offset from center is chosen on a normal probability curve.
|
||||||
|
def get_square_image(image):
|
||||||
|
h, w, _ = image.shape
|
||||||
|
if h == w:
|
||||||
|
return image
|
||||||
|
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
||||||
|
if h > w:
|
||||||
|
diff = h - w
|
||||||
|
center = diff // 2
|
||||||
|
top = max(int(center + offset * (center - 2)), 0)
|
||||||
|
return image[top:top + w, :, :]
|
||||||
|
else:
|
||||||
|
diff = w - h
|
||||||
|
center = diff // 2
|
||||||
|
left = max(int(center + offset * (center - 2)), 0)
|
||||||
|
return image[:, left:left + h, :]
|
||||||
|
|
||||||
class MultiScaleDataset(data.Dataset):
|
class MultiScaleDataset(data.Dataset):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(MultiScaleDataset, self).__init__()
|
super(MultiScaleDataset, self).__init__()
|
||||||
|
@ -25,23 +43,6 @@ class MultiScaleDataset(data.Dataset):
|
||||||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
|
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
|
||||||
self.corruptor = ImageCorruptor(opt)
|
self.corruptor = ImageCorruptor(opt)
|
||||||
|
|
||||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
|
||||||
# offset from center is chosen on a normal probability curve.
|
|
||||||
def get_square_image(self, image):
|
|
||||||
h, w, _ = image.shape
|
|
||||||
if h == w:
|
|
||||||
return image
|
|
||||||
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
|
||||||
if h > w:
|
|
||||||
diff = h - w
|
|
||||||
center = diff // 2
|
|
||||||
top = int(center + offset * (center - 2))
|
|
||||||
return image[top:top+w, :, :]
|
|
||||||
else:
|
|
||||||
diff = w - h
|
|
||||||
center = diff // 2
|
|
||||||
left = int(center + offset * (center - 2))
|
|
||||||
return image[:, left:left+h, :]
|
|
||||||
|
|
||||||
def recursively_extract_patches(self, input_img, result_list, depth):
|
def recursively_extract_patches(self, input_img, result_list, depth):
|
||||||
if depth >= self.num_scales:
|
if depth >= self.num_scales:
|
||||||
|
@ -62,7 +63,7 @@ class MultiScaleDataset(data.Dataset):
|
||||||
loaded_img = util.read_img(None, full_path, None)
|
loaded_img = util.read_img(None, full_path, None)
|
||||||
img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
|
img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
|
||||||
img_full2 = util.augment([img_full1], True, True)[0]
|
img_full2 = util.augment([img_full1], True, True)[0]
|
||||||
img_full3 = self.get_square_image(img_full2)
|
img_full3 = get_square_image(img_full2)
|
||||||
# This error crops up from time to time. I suspect an issue with util.read_img.
|
# This error crops up from time to time. I suspect an issue with util.read_img.
|
||||||
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
|
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
|
||||||
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
|
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user