Fix datasets
This commit is contained in:
parent
5edaf085e0
commit
45a489110f
0
codes/data/README.md
Normal file
0
codes/data/README.md
Normal file
|
@ -84,6 +84,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
|||
h, w = self.target_hq_size, self.target_hq_size
|
||||
else:
|
||||
hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted = imgs_hq, refs_hq, masks_hq, centers_hq
|
||||
hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] # This is done implicitly above..
|
||||
hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image.
|
||||
if h % hq_multiple != 0 or w % hq_multiple != 0:
|
||||
hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []
|
||||
|
|
|
@ -150,6 +150,9 @@ class TiledDataset(data.Dataset):
|
|||
thres_sz = self.opt['thres_sz']
|
||||
h, w, c = img.shape
|
||||
|
||||
if crop_sz > h:
|
||||
return []
|
||||
|
||||
h_space = np.arange(0, h - crop_sz + 1, step)
|
||||
if h - (h_space[-1] + crop_sz) > thres_sz:
|
||||
h_space = np.append(h_space, h - crop_sz)
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user