forked from mrq/DL-Art-School
Allow weighting of input data
This essentially allows you to give some datasets more importance than others for the purposes of reaching a more refined network.
This commit is contained in:
parent
edf0f8582e
commit
318a604405
|
@ -29,7 +29,7 @@ class LQGTDataset(data.Dataset):
|
||||||
self.paths_PIX, self.sizes_PIX = None, None
|
self.paths_PIX, self.sizes_PIX = None, None
|
||||||
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
|
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
|
||||||
|
|
||||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
|
||||||
if 'dataroot_LQ' in opt.keys():
|
if 'dataroot_LQ' in opt.keys():
|
||||||
self.paths_LQ = []
|
self.paths_LQ = []
|
||||||
if isinstance(opt['dataroot_LQ'], list):
|
if isinstance(opt['dataroot_LQ'], list):
|
||||||
|
|
|
@ -42,7 +42,7 @@ def _get_paths_from_lmdb(dataroot):
|
||||||
return paths, sizes
|
return paths, sizes
|
||||||
|
|
||||||
|
|
||||||
def get_image_paths(data_type, dataroot):
|
def get_image_paths(data_type, dataroot, weights=[]):
|
||||||
"""get image path list
|
"""get image path list
|
||||||
support lmdb or image files"""
|
support lmdb or image files"""
|
||||||
paths, sizes = None, None
|
paths, sizes = None, None
|
||||||
|
@ -52,8 +52,15 @@ def get_image_paths(data_type, dataroot):
|
||||||
elif data_type == 'img':
|
elif data_type == 'img':
|
||||||
if isinstance(dataroot, list):
|
if isinstance(dataroot, list):
|
||||||
paths = []
|
paths = []
|
||||||
for r in dataroot:
|
for i in range(len(dataroot)):
|
||||||
paths.extend(_get_paths_from_images(r))
|
r = dataroot[i]
|
||||||
|
extends = 1
|
||||||
|
|
||||||
|
# Weights have the effect of repeatedly adding the paths from the given root to the final product.
|
||||||
|
if weights:
|
||||||
|
extends = weights[i]
|
||||||
|
for j in range(extends):
|
||||||
|
paths.extend(_get_paths_from_images(r))
|
||||||
paths = sorted(paths)
|
paths = sorted(paths)
|
||||||
else:
|
else:
|
||||||
paths = sorted(_get_paths_from_images(dataroot))
|
paths = sorted(_get_paths_from_images(dataroot))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user