Allow weighting of input data
This essentially allows you to give some datasets more importance than others for the purposes of reaching a more refined network.
This commit is contained in:
parent
edf0f8582e
commit
318a604405
|
@ -29,7 +29,7 @@ class LQGTDataset(data.Dataset):
|
|||
self.paths_PIX, self.sizes_PIX = None, None
|
||||
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
|
||||
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
|
||||
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
|
||||
if 'dataroot_LQ' in opt.keys():
|
||||
self.paths_LQ = []
|
||||
if isinstance(opt['dataroot_LQ'], list):
|
||||
|
|
|
@ -42,7 +42,7 @@ def _get_paths_from_lmdb(dataroot):
|
|||
return paths, sizes
|
||||
|
||||
|
||||
def get_image_paths(data_type, dataroot):
|
||||
def get_image_paths(data_type, dataroot, weights=[]):
|
||||
"""get image path list
|
||||
support lmdb or image files"""
|
||||
paths, sizes = None, None
|
||||
|
@ -52,8 +52,15 @@ def get_image_paths(data_type, dataroot):
|
|||
elif data_type == 'img':
|
||||
if isinstance(dataroot, list):
|
||||
paths = []
|
||||
for r in dataroot:
|
||||
paths.extend(_get_paths_from_images(r))
|
||||
for i in range(len(dataroot)):
|
||||
r = dataroot[i]
|
||||
extends = 1
|
||||
|
||||
# Weights have the effect of repeatedly adding the paths from the given root to the final product.
|
||||
if weights:
|
||||
extends = weights[i]
|
||||
for j in range(extends):
|
||||
paths.extend(_get_paths_from_images(r))
|
||||
paths = sorted(paths)
|
||||
else:
|
||||
paths = sorted(_get_paths_from_images(dataroot))
|
||||
|
|
Loading…
Reference in New Issue
Block a user