Fix bug with multiscale_dataset
This commit is contained in:
parent
eb7df63592
commit
ea8c20c0e2
|
@ -13,6 +13,24 @@ import torchvision.transforms.functional as F
|
|||
from data.image_corruptor import ImageCorruptor
|
||||
|
||||
|
||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||
# offset from center is chosen on a normal probability curve.
|
||||
def get_square_image(image):
|
||||
h, w, _ = image.shape
|
||||
if h == w:
|
||||
return image
|
||||
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
||||
if h > w:
|
||||
diff = h - w
|
||||
center = diff // 2
|
||||
top = max(int(center + offset * (center - 2)), 0)
|
||||
return image[top:top + w, :, :]
|
||||
else:
|
||||
diff = w - h
|
||||
center = diff // 2
|
||||
left = max(int(center + offset * (center - 2)), 0)
|
||||
return image[:, left:left + h, :]
|
||||
|
||||
class MultiScaleDataset(data.Dataset):
|
||||
def __init__(self, opt):
|
||||
super(MultiScaleDataset, self).__init__()
|
||||
|
@ -25,23 +43,6 @@ class MultiScaleDataset(data.Dataset):
|
|||
self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
|
||||
self.corruptor = ImageCorruptor(opt)
|
||||
|
||||
# Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
|
||||
# offset from center is chosen on a normal probability curve.
|
||||
def get_square_image(self, image):
|
||||
h, w, _ = image.shape
|
||||
if h == w:
|
||||
return image
|
||||
offset = max(min(np.random.normal(scale=.3), 1.0), -1.0)
|
||||
if h > w:
|
||||
diff = h - w
|
||||
center = diff // 2
|
||||
top = int(center + offset * (center - 2))
|
||||
return image[top:top+w, :, :]
|
||||
else:
|
||||
diff = w - h
|
||||
center = diff // 2
|
||||
left = int(center + offset * (center - 2))
|
||||
return image[:, left:left+h, :]
|
||||
|
||||
def recursively_extract_patches(self, input_img, result_list, depth):
|
||||
if depth >= self.num_scales:
|
||||
|
@ -62,7 +63,7 @@ class MultiScaleDataset(data.Dataset):
|
|||
loaded_img = util.read_img(None, full_path, None)
|
||||
img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
|
||||
img_full2 = util.augment([img_full1], True, True)[0]
|
||||
img_full3 = self.get_square_image(img_full2)
|
||||
img_full3 = get_square_image(img_full2)
|
||||
# This error crops up from time to time. I suspect an issue with util.read_img.
|
||||
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
|
||||
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
|
||||
|
|
Loading…
Reference in New Issue
Block a user