forked from mrq/DL-Art-School
Dataset modifications
This commit is contained in:
parent f6098155cd
commit f3c1fc1bcd
@@ -17,6 +17,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
         self.for_eval = opt['eval'] if 'eval' in opt.keys() else False
         self.scale = opt['scale'] if not self.for_eval else 1
         self.paths = opt['paths']
+        self.corrupt_before_downsize = opt['corrupt_before_downsize'] if 'corrupt_before_downsize' in opt.keys() else False
         assert (self.target_hq_size // self.scale) % self.multiple == 0  # If we dont throw here, we get some really obscure errors.
         if not isinstance(self.paths, list):
             self.paths = [self.paths]
@@ -98,6 +99,8 @@ class BaseUnsupervisedImageDataset(data.Dataset):
     def synthesize_lq(self, hs, hrefs, hmasks, hcenters):
         h, w, _ = hs[0].shape
         ls, lrs, lms, lcs = [], [], [], []
+        if self.corrupt_before_downsize and not self.for_eval:
+            hs = self.corruptor.corrupt_images(hs)
         for hq, hq_ref, hq_mask, hq_center in zip(hs, hrefs, hmasks, hcenters):
             if self.for_eval:
                 ls.append(hq)
@@ -110,7 +113,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
             lms.append(cv2.resize(hq_mask, (h // self.scale, w // self.scale), interpolation=cv2.INTER_AREA))
             lcs.append(self.resize_point(hq_center, (h, w), ls[0].shape[:2]))
         # Corrupt the LQ image (only in eval mode)
-        if not self.for_eval:
+        if not self.corrupt_before_downsize and not self.for_eval:
             ls = self.corruptor.corrupt_images(ls)
         return ls, lrs, lms, lcs
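The hunks above add an opt-in corrupt_before_downsize flag and, when it is set, move the corruption step ahead of downsizing. Below is a minimal standalone sketch of that ordering; corrupt_fn and the naive strided downsample are stand-ins for the dataset's corruptor and cv2.resize, not the repo's actual implementation.

import numpy as np

def synthesize_lq_sketch(hq_images, scale, corrupt_fn, corrupt_before_downsize=False, for_eval=False):
    # Sketch of the control flow added above: with corrupt_before_downsize set,
    # degradation is applied at full resolution and then shrunk, which softens
    # the corruption artifacts; otherwise the downsized LQ images are corrupted.
    if corrupt_before_downsize and not for_eval:
        hq_images = corrupt_fn(hq_images)
    # Naive strided downsample stands in for cv2.resize(..., INTER_AREA).
    lq_images = [img[::scale, ::scale] for img in hq_images]
    if not corrupt_before_downsize and not for_eval:
        lq_images = corrupt_fn(lq_images)
    return lq_images

# Toy corruption for a quick check: additive gaussian noise.
def add_noise(images):
    return [np.clip(i.astype(np.float32) + np.random.normal(0, 10, i.shape), 0, 255).astype(np.uint8)
            for i in images]

hq = [np.full((64, 64, 3), 128, dtype=np.uint8)]
print(synthesize_lq_sketch(hq, scale=2, corrupt_fn=add_noise, corrupt_before_downsize=True)[0].shape)  # (32, 32, 3)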
@@ -50,9 +50,10 @@ if __name__ == '__main__':
         'force_multiple': 32,
         'scale': 2,
         'eval': False,
-        'fixed_corruptions': ['jpeg-broad'],
+        'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
         'random_corruptions': ['noise-5', 'none'],
-        'num_corrupts_per_image': 1
+        'num_corrupts_per_image': 1,
+        'corrupt_before_downsize': True,
     }

     ds = SingleImageDataset(opt)
@@ -61,7 +62,7 @@ if __name__ == '__main__':
     for i in range(0, len(ds)):
         o = ds[random.randint(0, len(ds))]
         #for k, v in o.items():
-        k = 'GT'
+        k = 'LQ'
         v = o[k]
         #if 'LQ' in k and 'path' not in k and 'center' not in k:
         #if 'full' in k:
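For reference, these are the dataset options the test harness exercises, with the defaults implied by the BaseUnsupervisedImageDataset hunk when a key is omitted. The constructor call is illustrative only and runs inside the repo, since the SingleImageDataset import is not part of this diff.

# Options exercised by the test harness above; comments note implied defaults.
opt = {
    'force_multiple': 32,
    'scale': 2,
    'eval': False,                                          # omitted -> False
    'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
    'random_corruptions': ['noise-5', 'none'],
    'num_corrupts_per_image': 1,
    'corrupt_before_downsize': True,                        # new key; omitted -> False
}
# ds = SingleImageDataset(opt)   # class lives elsewhere in the repo
# lq = ds[0]['LQ']               # the harness now inspects 'LQ' rather than 'GT'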
@@ -13,20 +13,20 @@ import torch
 def main():
     split_img = False
     opt = {}
-    opt['n_thread'] = 2
+    opt['n_thread'] = 8
     opt['compression_level'] = 90 # JPEG compression quality rating.
     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
     # compression time. If read raw images during training, use 0 for faster IO speed.

     opt['dest'] = 'file'
-    opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\images_sized'
-    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\vr\\paired_images'
+    opt['input_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\imgset3'
+    opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\imagesets\\256_with_ref_3'
     opt['crop_sz'] = [512, 1024] # the size of each sub-image
     opt['step'] = [512, 1024] # step of the sliding crop window
     opt['thres_sz'] = 128 # size threshold
     opt['resize_final_img'] = [.5, .25]
     opt['only_resize'] = False
-    opt['vertical_split'] = True
+    opt['vertical_split'] = False

     save_folder = opt['save_folder']
     if not osp.exists(save_folder):
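The crop_sz / step / thres_sz options above drive a sliding-window tiler. The sketch below shows how such start positions are typically laid out under those three parameters; the exact logic lives in the repo's script, and the crop_starts helper here is hypothetical.

def crop_starts(length, crop_sz, step, thres_sz):
    # Walk the axis in increments of `step`; if the strip left over after the
    # last full window is wider than `thres_sz`, add one extra crop anchored
    # flush with the far edge so no large region is discarded.
    starts = list(range(0, length - crop_sz + 1, step))
    covered = (starts[-1] + crop_sz) if starts else 0
    if length - covered > thres_sz:
        starts.append(length - crop_sz)
    return starts

# 1200 px along one axis, 512 px crops, 512 px step, 128 px threshold:
print(crop_starts(1200, crop_sz=512, step=512, thres_sz=128))  # [0, 512, 688]
# The 176 px remainder exceeds thres_sz, so a final edge-aligned crop is kept.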
@@ -178,10 +178,14 @@ class TiledDataset(data.Dataset):
     def get(self, index, split_mode, left_img):
         path = self.images[index]
         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+        if img is None or len(img.shape) == 2:
+            return None
+
         h, w, c = img.shape

         # Uncomment to filter any image that doesnt meet a threshold size.
-        if min(h,w) < 1024:
+        if min(h,w) < 512:
             return None
         # Greyscale not supported.
         if len(img.shape) == 2:
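The new guard is needed because cv2.imread returns None for files it cannot decode, and with IMREAD_UNCHANGED a greyscale image comes back as a 2-D array with no channel axis, which would make the later h, w, c unpack raise. A standalone restatement of the check, wrapped in a hypothetical load_color_image helper:

import cv2

def load_color_image(path, min_dim=512):
    # Mirrors the guard added in TiledDataset.get: skip unreadable files and
    # greyscale images, then apply the same minimum-size filter (now 512 px).
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None or len(img.shape) == 2:
        return None
    h, w, c = img.shape
    if min(h, w) < min_dim:
        return None
    return img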
codes/scripts/srflow_latent_space_playground.py: new file, 0 lines