Dataset modifications

This commit is contained in:
James Betker 2020-11-24 13:20:12 -07:00
parent f6098155cd
commit f3c1fc1bcd
4 changed files with 17 additions and 9 deletions

View File

@ -17,6 +17,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
self.for_eval = opt['eval'] if 'eval' in opt.keys() else False self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
self.scale = opt['scale'] if not self.for_eval else 1 self.scale = opt['scale'] if not self.for_eval else 1
self.paths = opt['paths'] self.paths = opt['paths']
self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors. assert (self.target_hq_size // self.scale) % self.multiple == 0 # If we dont throw here, we get some really obscure errors.
if not isinstance(self.paths, list): if not isinstance(self.paths, list):
self.paths = [self.paths] self.paths = [self.paths]
@ -98,6 +99,8 @@ class BaseUnsupervisedImageDataset(data.Dataset):
def synthesize_lq(self, hs, hrefs, hmasks, hcenters): def synthesize_lq(self, hs, hrefs, hmasks, hcenters):
h, w, _ = hs[0].shape h, w, _ = hs[0].shape
ls, lrs, lms, lcs = [], [], [], [] ls, lrs, lms, lcs = [], [], [], []
if self.corrupt_before_downsize and not self.for_eval:
hs = self.corruptor.corrupt_images(hs)
for hq, hq_ref, hq_mask, hq_center in zip(hs, hrefs, hmasks, hcenters): for hq, hq_ref, hq_mask, hq_center in zip(hs, hrefs, hmasks, hcenters):
if self.for_eval: if self.for_eval:
ls.append(hq) ls.append(hq)
@ -110,7 +113,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA)) lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2])) lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
# Corrupt the LQ image (only in eval mode) # Corrupt the LQ image (only in eval mode)
if not self.for_eval: if not self.corrupt_before_downsize and not self.for_eval:
ls = self.corruptor.corrupt_images(ls) ls = self.corruptor.corrupt_images(ls)
return ls, lrs, lms, lcs return ls, lrs, lms, lcs

View File

@ -50,9 +50,10 @@ if __name__ == '__main__':
'force_multiple': 32, 'force_multiple': 32,
'scale': 2, 'scale': 2,
'eval': False, 'eval': False,
'fixed_corruptions': ['jpeg-broad'], 'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
'random_corruptions': ['noise-5', 'none'], 'random_corruptions': ['noise-5', 'none'],
'num_corrupts_per_image': 1 'num_corrupts_per_image': 1,
'corrupt_before_downsize': True,
} }
ds = SingleImageDataset(opt) ds = SingleImageDataset(opt)
@ -61,7 +62,7 @@ if __name__ == '__main__':
for i in range(0, len(ds)): for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))] o = ds[random.randint(0, len(ds))]
#for k, v in o.items(): #for k, v in o.items():
k = 'GT' k = 'LQ'
v = o[k] v = o[k]
#if 'LQ' in k and 'path' not in k and 'center' not in k: #if 'LQ' in k and 'path' not in k and 'center' not in k:
#if 'full' in k: #if 'full' in k:

View File

@ -13,20 +13,20 @@ import torch
def main(): def main():
split_img = False split_img = False
opt = {} opt = {}
opt['n_thread'] = 2 opt['n_thread'] = 8
opt['compression_level'] = 90 # JPEG compression quality rating. opt['compression_level'] = 90 # JPEG compression quality rating.
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed. # compression time. If read raw images during training, use 0 for faster IO speed.
opt['dest'] = 'file' opt['dest'] = 'file'
opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\images_sized' opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imgset3'
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\paired_images' opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_3'
opt['crop_sz'] = [512, 1024] # the size of each sub-image opt['crop_sz'] = [512, 1024] # the size of each sub-image
opt['step'] = [512, 1024] # step of the sliding crop window opt['step'] = [512, 1024] # step of the sliding crop window
opt['thres_sz'] = 128 # size threshold opt['thres_sz'] = 128 # size threshold
opt['resize_final_img'] = [.5, .25] opt['resize_final_img'] = [.5, .25]
opt['only_resize'] = False opt['only_resize'] = False
opt['vertical_split'] = True opt['vertical_split'] = False
save_folder = opt['save_folder'] save_folder = opt['save_folder']
if not osp.exists(save_folder): if not osp.exists(save_folder):
@ -178,10 +178,14 @@ class TiledDataset(data.Dataset):
def get(self, index, split_mode, left_img): def get(self, index, split_mode, left_img):
path = self.images[index] path = self.images[index]
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img is None or len(img.shape) == 2:
return None
h, w, c = img.shape h, w, c = img.shape
# Uncomment to filter any image that doesnt meet a threshold size. # Uncomment to filter any image that doesnt meet a threshold size.
if min(h,w) < 1024: if min(h,w) < 512:
return None return None
# Greyscale not supported. # Greyscale not supported.
if len(img.shape) == 2: if len(img.shape) == 2: