Allow weighting of input data

This essentially allows you to give some datasets more
importance than others for the purposes of reaching a more
refined network.
This commit is contained in:
James Betker 2020-06-04 10:05:21 -06:00
parent edf0f8582e
commit 318a604405
2 changed files with 11 additions and 4 deletions

View File

@ -29,7 +29,7 @@ class LQGTDataset(data.Dataset):
self.paths_PIX, self.sizes_PIX = None, None
self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights'])
if 'dataroot_LQ' in opt.keys():
self.paths_LQ = []
if isinstance(opt['dataroot_LQ'], list):

View File

@ -42,7 +42,7 @@ def _get_paths_from_lmdb(dataroot):
return paths, sizes
def get_image_paths(data_type, dataroot):
def get_image_paths(data_type, dataroot, weights=[]):
"""get image path list
support lmdb or image files"""
paths, sizes = None, None
@ -52,8 +52,15 @@ def get_image_paths(data_type, dataroot):
elif data_type == 'img':
if isinstance(dataroot, list):
paths = []
for r in dataroot:
paths.extend(_get_paths_from_images(r))
for i in range(len(dataroot)):
r = dataroot[i]
extends = 1
# Weights have the effect of repeatedly adding the paths from the given root to the final product.
if weights:
extends = weights[i]
for j in range(extends):
paths.extend(_get_paths_from_images(r))
paths = sorted(paths)
else:
paths = sorted(_get_paths_from_images(dataroot))