Update validate_data to work with SingleImageDataset
This commit is contained in:
parent
35469f08e2
commit
efbf6b737b
|
@ -11,6 +11,12 @@ class SingleImageDataset(BaseUnsupervisedImageDataset):
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
super(SingleImageDataset, self).__init__(opt)
|
super(SingleImageDataset, self).__init__(opt)
|
||||||
|
|
||||||
|
def get_paths(self):
|
||||||
|
for i in range(len(self)):
|
||||||
|
chunk_ind = bisect_left(self.starting_indices, i)
|
||||||
|
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == i else chunk_ind-1
|
||||||
|
yield self.chunks[chunk_ind].tiles[i-self.starting_indices[chunk_ind]]
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
chunk_ind = bisect_left(self.starting_indices, item)
|
chunk_ind = bisect_left(self.starting_indices, item)
|
||||||
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
chunk_ind = chunk_ind if chunk_ind < len(self.starting_indices) and self.starting_indices[chunk_ind] == item else chunk_ind-1
|
||||||
|
|
|
@ -15,7 +15,7 @@ from skimage import io
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_feature_net.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_exd_imgset_spsr7.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
@ -52,7 +52,7 @@ def main():
|
||||||
len(train_set), train_size))
|
len(train_set), train_size))
|
||||||
assert train_loader is not None
|
assert train_loader is not None
|
||||||
|
|
||||||
tq_ldr = tqdm(train_set.paths_GT)
|
tq_ldr = tqdm(train_set.get_paths())
|
||||||
for path in tq_ldr:
|
for path in tq_ldr:
|
||||||
try:
|
try:
|
||||||
_ = io.imread(path)
|
_ = io.imread(path)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user