Fix datasets

This commit is contained in:
James Betker 2020-11-26 11:50:38 -07:00
parent 5edaf085e0
commit 45a489110f
4 changed files with 5 additions and 1 deletions

0
codes/data/README.md Normal file
View File

View File

@ -84,6 +84,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
h, w = self.target_hq_size, self.target_hq_size
else:
hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted = imgs_hq, refs_hq, masks_hq, centers_hq
hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] # This is done implicitly above..
hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image.
if h % hq_multiple != 0 or w % hq_multiple != 0:
hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], []

View File

@ -150,6 +150,9 @@ class TiledDataset(data.Dataset):
thres_sz = self.opt['thres_sz']
h, w, c = img.shape
if crop_sz > h:
return []
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()