From 318a60440546dd516092ad81c5dbc2f0ccf87177 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 4 Jun 2020 10:05:21 -0600 Subject: [PATCH] Allow weighting of input data This essentially allows you to give some datasets more importance than others for the purposes of reaching a more refined network. --- codes/data/LQGT_dataset.py | 2 +- codes/data/util.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py index 944c1a83..aae2ff89 100644 --- a/codes/data/LQGT_dataset.py +++ b/codes/data/LQGT_dataset.py @@ -29,7 +29,7 @@ class LQGTDataset(data.Dataset): self.paths_PIX, self.sizes_PIX = None, None self.LQ_env, self.GT_env, self.PIX_env = None, None, None # environments for lmdbs - self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) + self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'], opt['dataroot_GT_weights']) if 'dataroot_LQ' in opt.keys(): self.paths_LQ = [] if isinstance(opt['dataroot_LQ'], list): diff --git a/codes/data/util.py b/codes/data/util.py index 719c65f3..4b7f5fc4 100644 --- a/codes/data/util.py +++ b/codes/data/util.py @@ -42,7 +42,7 @@ def _get_paths_from_lmdb(dataroot): return paths, sizes -def get_image_paths(data_type, dataroot): +def get_image_paths(data_type, dataroot, weights=[]): """get image path list support lmdb or image files""" paths, sizes = None, None @@ -52,8 +52,15 @@ def get_image_paths(data_type, dataroot): elif data_type == 'img': if isinstance(dataroot, list): paths = [] - for r in dataroot: - paths.extend(_get_paths_from_images(r)) + for i in range(len(dataroot)): + r = dataroot[i] + extends = 1 + + # Weights have the effect of repeatedly adding the paths from the given root to the final product. + if weights: + extends = weights[i] + for j in range(extends): + paths.extend(_get_paths_from_images(r)) paths = sorted(paths) else: paths = sorted(_get_paths_from_images(dataroot))