Update validate_data to work with SingleImageDataset

This commit is contained in:
James Betker 2020-10-02 08:58:34 -06:00
parent 35469f08e2
commit efbf6b737b
2 changed files with 8 additions and 2 deletions

View File

@ -11,6 +11,12 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
def __init__(self, opt): def __init__(self, opt):
super(SingleImageDataset, self).__init__(opt) super(SingleImageDataset, self).__init__(opt)
def get_paths(self):
for i in range(len(self)):
chunk_ind = bisect_left(self.starting_indices, i)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1
yield self.chunks[chunk_ind].tiles[i-self.starting_indices[chunk_ind]]
def __getitem__(self, item): def __getitem__(self, item):
chunk_ind = bisect_left(self.starting_indices, item) chunk_ind = bisect_left(self.starting_indices, item)
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1 chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1

View File

@ -15,7 +15,7 @@ from skimage import io
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_feature_net.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
@ -52,7 +52,7 @@ def main():
len(train_set), train_size)) len(train_set), train_size))
assert train_loader is not None assert train_loader is not None
tq_ldr = tqdm(train_set.paths_GT) tq_ldr = tqdm(train_set.get_paths())
for path in tq_ldr: for path in tq_ldr:
try: try:
_ = io.imread(path) _ = io.imread(path)