Fix bug that causes multiscale dataset to crash

This commit is contained in:
James Betker 2020-10-30 14:01:24 -06:00
parent 74738489b9
commit b24ff3c88d

View File

@ -59,11 +59,16 @@ class MultiScaleDataset(data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
# get full size image # get full size image
full_path = self.paths_hq[index % len(self.paths_hq)] full_path = self.paths_hq[index % len(self.paths_hq)]
img_full = util.read_img(None, full_path, None) loaded_img = util.read_img(None, full_path, None)
img_full = util.channel_convert(img_full.shape[2], 'RGB', [img_full])[0] img_full1 = util.channel_convert(loaded_img.shape[2], 'RGB', [loaded_img])[0]
img_full = util.augment([img_full], True, True)[0] img_full2 = util.augment([img_full1], True, True)[0]
img_full = self.get_square_image(img_full) img_full3 = self.get_square_image(img_full2)
img_full = cv2.resize(img_full, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA) # This error crops up from time to time. I suspect an issue with util.read_img.
if img_full3.shape[0] == 0 or img_full3.shape[1] == 0:
print("Error with image: %s. Loaded image shape: %s" % (full_path,str(loaded_img.shape)), str(img_full1.shape), str(img_full2.shape), str(img_full3.shape))
# Attempt to recover by just using a fixed array of zeros, which the downstream networks should be fine training against, within reason.
img_full3 = np.zeros((1024,1024,3), dtype=np.int)
img_full = cv2.resize(img_full3, (self.hq_size_cap, self.hq_size_cap), interpolation=cv2.INTER_AREA)
patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)] patches_hq = [cv2.resize(img_full, (self.tile_size, self.tile_size), interpolation=cv2.INTER_AREA)]
self.recursively_extract_patches(img_full, patches_hq, 1) self.recursively_extract_patches(img_full, patches_hq, 1)
# Image corruption is applied against the full size image for this dataset. # Image corruption is applied against the full size image for this dataset.