From 45a489110f7aa1027ab5503963e093e420f8ff60 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 26 Nov 2020 11:50:38 -0700 Subject: [PATCH] Fix datasets --- codes/data/README.md | 0 codes/data/base_unsupervised_image_dataset.py | 1 + codes/scripts/extract_subimages_with_ref.py | 3 +++ codes/train.py | 2 +- 4 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 codes/data/README.md diff --git a/codes/data/README.md b/codes/data/README.md new file mode 100644 index 00000000..e69de29b diff --git a/codes/data/base_unsupervised_image_dataset.py b/codes/data/base_unsupervised_image_dataset.py index 20f2805b..53d2ca9c 100644 --- a/codes/data/base_unsupervised_image_dataset.py +++ b/codes/data/base_unsupervised_image_dataset.py @@ -84,6 +84,7 @@ class BaseUnsupervisedImageDataset(data.Dataset): h, w = self.target_hq_size, self.target_hq_size else: hqs_adjusted, hq_refs_adjusted, hq_masks_adjusted, hq_centers_adjusted = imgs_hq, refs_hq, masks_hq, centers_hq + hq_masks_adjusted = [m.squeeze(-1) for m in hq_masks_adjusted] # This is done implicitly above.. hq_multiple = self.multiple * self.scale # Multiple must apply to LQ image. if h % hq_multiple != 0 or w % hq_multiple != 0: hqs_conformed, hq_refs_conformed, hq_masks_conformed, hq_centers_conformed = [], [], [], [] diff --git a/codes/scripts/extract_subimages_with_ref.py b/codes/scripts/extract_subimages_with_ref.py index 8c0460c1..de888043 100644 --- a/codes/scripts/extract_subimages_with_ref.py +++ b/codes/scripts/extract_subimages_with_ref.py @@ -150,6 +150,9 @@ class TiledDataset(data.Dataset): thres_sz = self.opt['thres_sz'] h, w, c = img.shape + if crop_sz > h: + return [] + h_space = np.arange(0, h - crop_sz + 1, step) if h - (h_space[-1] + crop_sz) > thres_sz: h_space = np.append(h_space, h - crop_sz) diff --git a/codes/train.py b/codes/train.py index 251a079d..cc83a3d2 100644 --- a/codes/train.py +++ b/codes/train.py @@ -291,7 +291,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_using_rrdb_features.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_adrianna_srflow8x.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()