diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py
index 6c21f7a2..b9acebeb 100644
--- a/codes/data/chunk_with_reference.py
+++ b/codes/data/chunk_with_reference.py
@@ -4,11 +4,15 @@ import torch
 import numpy as np
 
 # Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
+from utils.util import opt_get
+
+
 class ChunkWithReference:
     def __init__(self, opt, path):
         self.path = path.path
         self.tiles, _ = util.get_image_paths('img', self.path)
-        self.strict = opt['strict'] if 'strict' in opt.keys() else True
+        self.need_metadata = opt_get(opt, ['strict'], False) or opt_get(opt, ['needs_metadata'], False)
+        self.need_ref = opt_get(opt, ['need_ref'], False)
         if 'ignore_first' in opt.keys():
             self.tiles = self.tiles[opt['ignore_first']:]
 
@@ -21,18 +25,19 @@ class ChunkWithReference:
 
     def __getitem__(self, item):
         tile = self.read_image_or_get_zero(self.tiles[item])
-        if osp.exists(osp.join(self.path, "ref.jpg")):
+        if self.need_ref and osp.exists(osp.join(self.path, "ref.jpg")):
             tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
-            centers = torch.load(osp.join(self.path, "centers.pt"))
             ref = self.read_image_or_get_zero(osp.join(self.path, "ref.jpg"))
-            if tile_id in centers.keys():
-                center, tile_width = centers[tile_id]
-            elif self.strict:
-                print("Could not find the given tile id in the accompanying centers.pt. This generally means that "
-                      "centers.pt was overwritten at some point e.g. by duplicate data. If you don't care about tile "
-                      "centers, consider passing strict=false to the dataset options. (Note: you must re-build your"
-                      "caches for this setting change to take effect.)")
-                raise FileNotFoundError(tile_id, self.tiles[item])
+            if self.need_metadata:
+                centers = torch.load(osp.join(self.path, "centers.pt"))
+                if tile_id in centers.keys():
+                    center, tile_width = centers[tile_id]
+                else:
+                    print("Could not find the given tile id in the accompanying centers.pt. This generally means that "
+                          "centers.pt was overwritten at some point e.g. by duplicate data. If you don't care about tile "
+                          "centers, consider passing strict=false to the dataset options. (Note: you must re-build your"
+                          "caches for this setting change to take effect.)")
+                    raise FileNotFoundError(tile_id, self.tiles[item])
         else:
             center = torch.tensor([128, 128], dtype=torch.long)
             tile_width = 256
diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index 345423f3..16f885d2 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -119,7 +119,7 @@ class ImageCorruptor:
                 raise NotImplementedError("specified jpeg corruption doesn't exist")
            # JPEG compression
            qf = (rand_int % range + lo)
-            # cv2's jpeg compression is "odd". It introduces artifacts. Use PIL instead.
+            # Use PIL to perform a mock compression to a data buffer, then swap back to cv2.
             img = (img * 255).astype(np.uint8)
             img = Image.fromarray(img)
             buffer = BytesIO()
diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py
index d5ae7d88..4048f197 100644
--- a/codes/data/single_image_dataset.py
+++ b/codes/data/single_image_dataset.py
@@ -50,7 +50,6 @@ if __name__ == '__main__':
         'force_multiple': 32,
         'scale': 2,
         'eval': False,
-        'strict': False,
         'fixed_corruptions': ['jpeg-broad', 'gaussian_blur'],
         'random_corruptions': ['noise-5', 'none'],
         'num_corrupts_per_image': 1,
diff --git a/codes/train.py b/codes/train.py
index 93045042..47aa36dd 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -293,7 +293,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_rrdb4x_23bl_opt.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_xx_faces_glean.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
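
Note: the reworded comment in image_corruptor.py describes a PIL round-trip JPEG compression. A minimal sketch of that technique follows; it is not the repository's exact code, and the helper name mock_jpeg_compress plus the float-in-[0,1] image convention are assumptions taken from the surrounding hunk.

    import numpy as np
    from io import BytesIO
    from PIL import Image

    def mock_jpeg_compress(img_float, qf):
        # img_float: HxWxC float image in [0, 1]; qf: JPEG quality factor.
        pil_img = Image.fromarray((img_float * 255).astype(np.uint8))
        buffer = BytesIO()
        pil_img.save(buffer, "JPEG", quality=qf)  # compress to an in-memory buffer
        buffer.seek(0)
        # Decode the compressed bytes and return to the float convention used by the cv2 pipeline.
        return np.array(Image.open(buffer)).astype(np.float32) / 255.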