diff --git a/codes/data/util.py b/codes/data/util.py index 181896b1..719c65f3 100644 --- a/codes/data/util.py +++ b/codes/data/util.py @@ -50,7 +50,13 @@ def get_image_paths(data_type, dataroot): if data_type == 'lmdb': paths, sizes = _get_paths_from_lmdb(dataroot) elif data_type == 'img': - paths = sorted(_get_paths_from_images(dataroot)) + if isinstance(dataroot, list): + paths = [] + for r in dataroot: + paths.extend(_get_paths_from_images(r)) + paths = sorted(paths) + else: + paths = sorted(_get_paths_from_images(dataroot)) else: raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) return paths, sizes