Dataset modifications

James Betker 2020-11-24 13:20:12 -07:00
parent f6098155cd
commit f3c1fc1bcd
4 changed files with 17 additions and 9 deletions

View File

@@ -17,6 +17,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
self.scale = opt['scale'] if not self.for_eval else 1
self.paths = opt['paths']
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we don't throw here, we get some really obscure errors.
if not isinstance(self.paths, list):
self.paths = [self.paths]
@@ -98,6 +99,8 @@ class BaseUnsupervisedImageDataset(data.Dataset):
def synthesize_lq(self, hs, hrefs, hmasks, hcenters):
h, w, _ = hs[0].shape
ls, lrs, lms, lcs = [], [], [], []
if self.corrupt_before_downsize and not self.for_eval:
hs = self.corruptor.corrupt_images(hs)
for hq, hq_ref, hq_mask, hq_center in zip(hs, hrefs, hmasks, hcenters):
if self.for_eval:
ls.append(hq)
@@ -110,7 +113,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
# Corrupt the LQ image (skipped in eval mode)
if not self.for_eval:
if not self.corrupt_before_downsize and not self.for_eval:
ls = self.corruptor.corrupt_images(ls)
return ls, lrs, lms, lcs
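
Not part of the commit, but for context: the new corrupt_before_downsize option simply moves the corruption step from the LQ side of the resize to the HQ side. A minimal sketch of the two orderings, assuming only the corrupt_images() call and the flag from the diff above (everything else is illustrative):

import cv2

def make_lq(hq_images, corruptor, scale, corrupt_before_downsize):
    # Flag set: corrupt at HQ resolution, so JPEG/blur/noise artifacts are
    # shrunk together with the image during downsizing.
    if corrupt_before_downsize:
        hq_images = corruptor.corrupt_images(hq_images)
    lq_images = [
        cv2.resize(img, (img.shape[1] // scale, img.shape[0] // scale),
                   interpolation=cv2.INTER_AREA)
        for img in hq_images
    ]
    # Flag unset (previous behaviour): corrupt after downsizing, at LQ resolution.
    if not corrupt_before_downsize:
        lq_images = corruptor.corrupt_images(lq_images)
    return lq_images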

View File

@@ -50,9 +50,10 @@ if __name__ == '__main__':
'force_multiple': 32,
'scale': 2,
'eval': False,
'fixed_corruptions': ['jpeg-broad'],
'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1
'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
}
ds = SingleImageDataset(opt)
@@ -61,7 +62,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
#for k, v in o.items():
k = 'GT'
k = 'LQ'
v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k:
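
The corruption options above are consumed by the dataset's ImageCorruptor, whose implementation is not shown in this diff. The option names suggest that every entry in fixed_corruptions is always applied and that num_corrupts_per_image extra corruptions are drawn from random_corruptions; the sketch below is only a guess at that reading, with apply_corruption() as a hypothetical stand-in for the per-corruption logic:

import random

def corrupt_image(img, opt, apply_corruption):
    # Always apply the fixed corruptions (e.g. jpeg-broad, gaussian_blur).
    for name in opt.get('fixed_corruptions', []):
        img = apply_corruption(img, name)
    # Then draw a few extra corruptions from the random pool; 'none' acts as a
    # no-op entry that lets some images escape random corruption entirely.
    pool = opt.get('random_corruptions', [])
    if pool:
        for name in random.choices(pool, k=opt.get('num_corrupts_per_image', 0)):
            if name != 'none':
                img = apply_corruption(img, name)
    return img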

View File

@@ -13,20 +13,20 @@ import torch
def main():
split_img = False
opt = {}
opt['n_thread'] = 2
opt['n_thread'] = 8
opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If reading raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file'
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\images_sized'
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imgset3'
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_3'
opt['crop_sz'] = [512, 1024] # the size of each sub-image
opt['step'] = [512, 1024] # step of the sliding crop window
opt['thres_sz'] = 128 # size threshold
opt['resize_final_img'] = [.5, .25]
opt['only_resize'] = False
opt['vertical_split'] = True
opt['vertical_split'] = False
save_folder = opt['save_folder']
if not osp.exists(save_folder):
@@ -178,10 +178,14 @@ class TiledDataset(data.Dataset):
def get(self, index, split_mode, left_img):
path = self.images[index]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img is None or len(img.shape) == 2:
return None
h, w, c = img.shape
# Uncomment to filter any image that doesn't meet a threshold size.
if min(h,w) < 1024:
if min(h,w) < 512:
return None
# Greyscale not supported.
if len(img.shape) == 2:
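
Taken together, the checks in TiledDataset.get() reduce to a small validity filter: skip images that cannot be read, are greyscale, or fall below the (now 512 px) minimum dimension. A standalone sketch of that filter, where load_valid_image is a made-up helper name:

import cv2

def load_valid_image(path, min_dim=512):
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # Unreadable files and greyscale (2-dimensional) images are rejected.
    if img is None or len(img.shape) == 2:
        return None
    h, w, _ = img.shape
    # Filter out images that don't meet the minimum size threshold.
    if min(h, w) < min_dim:
        return None
    return img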