From efbf6b737b7a2d6aa90215301c7d0b7036435b8a Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 2 Oct 2020 08:58:34 -0600 Subject: [PATCH] Update validate_data to work with SingleImageDataset --- codes/data/single_image_dataset.py | 6 ++++++ codes/data_scripts/validate_data.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py index d1a15f7e..3bf5b0cb 100644 --- a/codes/data/single_image_dataset.py +++ b/codes/data/single_image_dataset.py @@ -11,6 +11,12 @@ class SingleImageDataset(BaseUnsupervisedImageDataset): def __init__(self, opt): super(SingleImageDataset, self).__init__(opt) + def get_paths(self): + for i in range(len(self)): + chunk_ind = bisect_left(self.starting_indices, i) + chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1 + yield self.chunks[chunk_ind].tiles[i-self.starting_indices[chunk_ind]] + def __getitem__(self, item): chunk_ind = bisect_left(self.starting_indices, item) chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 diff --git a/codes/data_scripts/validate_data.py b/codes/data_scripts/validate_data.py index 4bb2fe04..9789d7bc 100644 --- a/codes/data_scripts/validate_data.py +++ b/codes/data_scripts/validate_data.py @@ -15,7 +15,7 @@ from skimage import io def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_feature_net.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -52,7 +52,7 @@ def main(): len(train_set), train_size)) assert train_loader is not None - tq_ldr = tqdm(train_set.paths_GT) + tq_ldr = tqdm(train_set.get_paths()) for path in tq_ldr: try: _ = io.imread(path)