diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index f3fff285..8bb498dd 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -38,6 +38,8 @@ def create_dataset(dataset_opt):
         from data.Downsample_dataset import DownsampleDataset as D
     elif mode == 'fullimage':
         from data.full_image_dataset import FullImageDataset as D
+    elif mode == 'single_image_extensible':
+        from data.single_image_dataset import SingleImageDataset as D
     elif mode == 'combined':
         from data.combined_dataset import CombinedDataset as D
     else:
diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py
index 40f83302..8704b327 100644
--- a/codes/data/chunk_with_reference.py
+++ b/codes/data/chunk_with_reference.py
@@ -6,13 +6,16 @@ import numpy as np
 # Iterable that reads all the images in a directory that contains a reference image, tile images and center coordinates.
 class ChunkWithReference:
     def __init__(self, opt, path):
-        self.opt = opt
+        self.reload(opt)
         self.path = path.path
-        self.ref = None  # This is loaded on the fly.
-        self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else True
         self.tiles, _ = util.get_image_paths('img', path)
         self.centers = None
 
+    def reload(self, opt):
+        self.opt = opt
+        self.ref = None  # This is loaded on the fly.
+        self.cache_ref = opt['cache_ref'] if 'cache_ref' in opt.keys() else False
+
     def __getitem__(self, item):
         # Load centers on the fly and always cache.
         if self.centers is None:
@@ -20,10 +23,9 @@ class ChunkWithReference:
         if self.cache_ref:
             if self.ref is None:
                 self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
-                self.centers = torch.load(osp.join(self.path, "centers.pt"))
             ref = self.ref
         else:
-            self.ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
+            ref = util.read_img(None, osp.join(self.path, "ref.jpg"), rgb=True)
         tile = util.read_img(None, self.tiles[item], rgb=True)
         tile_id = int(osp.splitext(osp.basename(self.tiles[item]))[0])
         center, tile_width = self.centers[tile_id]
diff --git a/codes/data/image_corruptor.py b/codes/data/image_corruptor.py
index 196114f1..935d85dc 100644
--- a/codes/data/image_corruptor.py
+++ b/codes/data/image_corruptor.py
@@ -10,10 +10,16 @@ from io import BytesIO
 class ImageCorruptor:
     def __init__(self, opt):
         self.num_corrupts = opt['num_corrupts_per_image'] if 'num_corrupts_per_image' in opt.keys() else 2
+        if self.num_corrupts == 0:
+            return
         self.fixed_corruptions = opt['fixed_corruptions']
         self.random_corruptions = opt['random_corruptions']
+        self.blur_scale = opt['corruption_blur_scale'] if 'corruption_blur_scale' in opt.keys() else 1
 
     def corrupt_images(self, imgs):
+        if self.num_corrupts == 0:
+            return imgs
+
         augmentations = random.choices(self.random_corruptions, k=self.num_corrupts)
         # Source of entropy, which should be used across all images.
         rand_int_f = random.randint(1, 999999)
@@ -38,11 +44,11 @@ class ImageCorruptor:
                 img = img / 255
             elif 'gaussian_blur' in aug:
                 # Gaussian Blur
-                kernel = 2 * (rand_int % 3) + 1
+                kernel = 2 * self.blur_scale * (rand_int % 3) + 1
                 img = cv2.GaussianBlur(img, (kernel, kernel), 3)
             elif 'motion_blur' in aug:
                 # Motion blur
-                intensity = 2 * (rand_int % 3) + 1
+                intensity = 2 * self.blur_scale * (rand_int % 3) + 1
                 angle = (rand_int // 3) % 360
                 k = np.zeros((intensity, intensity), dtype=np.float32)
                 k[(intensity - 1) // 2, :] = np.ones(intensity, dtype=np.float32)
@@ -52,7 +58,7 @@ class ImageCorruptor:
                 img = cv2.filter2D(img, -1, k)
             elif 'smooth_blur' in aug:
                 # Smooth blur
-                kernel = 2 * (rand_int % 3) + 1
+                kernel = 2 * self.blur_scale * (rand_int % 3) + 1
                 img = cv2.blur(img, ksize=(kernel, kernel))
             elif 'block_noise' in aug:
                 # Large distortion blocks in part of an img, such as is used to mask out a face.
diff --git a/codes/data/single_image_dataset.py b/codes/data/single_image_dataset.py
index 6354d3a5..cdea7f4a 100644
--- a/codes/data/single_image_dataset.py
+++ b/codes/data/single_image_dataset.py
@@ -14,6 +14,7 @@ import torchvision.transforms.functional as F
 
 class SingleImageDataset(data.Dataset):
     def __init__(self, opt):
+        self.opt = opt
         self.corruptor = ImageCorruptor(opt)
         self.target_hq_size = opt['target_size'] if 'target_size' in opt.keys() else None
         self.multiple = opt['force_multiple'] if 'force_multiple' in opt.keys() else 1
@@ -33,6 +34,9 @@ class SingleImageDataset(data.Dataset):
         cache_path = os.path.join(path, 'cache.pth')
         if os.path.exists(cache_path):
             chunks = torch.load(cache_path)
+            # Update the options.
+            for c in chunks:
+                c.reload(opt)
         else:
             chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
             torch.save(chunks, cache_path)
@@ -101,7 +105,7 @@ class SingleImageDataset(data.Dataset):
             lq_ref = torch.cat([lq_ref, lq_mask], dim=0)
 
         return {'LQ': lq, 'GT': hq, 'gt_fullsize_ref': hq_ref, 'lq_fullsize_ref': lq_ref,
-                'lq_center': lq_center, 'gt_center': hq_center,
+                'lq_center': torch.tensor(lq_center, dtype=torch.long), 'gt_center': torch.tensor(hq_center, dtype=torch.long),
                 'LQ_path': path, 'GT_path': path}
 
     def __len__(self):
diff --git a/codes/data_scripts/use_discriminator_as_filter.py b/codes/data_scripts/use_discriminator_as_filter.py
index 6615dcdf..89df1344 100644
--- a/codes/data_scripts/use_discriminator_as_filter.py
+++ b/codes/data_scripts/use_discriminator_as_filter.py
@@ -60,16 +60,20 @@ if __name__ == "__main__":
     util.mkdir(dataset_dir)
 
     tq = tqdm(test_loader)
+    removed = 0
     for data in tq:
         model.feed_data(data, need_GT=True)
         model.test()
         results = model.eval_state['discriminator_out'][0]
+        print(torch.mean(results), torch.max(results), torch.min(results))
         for i in range(results.shape[0]):
-            imname = osp.basename(data['GT_path'][i])
-            if results[i] < 1:
-                torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
-            else:
-                torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
+            if results[i] < .8:
+                os.remove(data['GT_path'][i])
+                removed += 1
+            #imname = osp.basename(data['GT_path'][i])
+            #if results[i] > .8:
+            #    torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname))
+            #else:
+            #    torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname))
 
-    # log
-    logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader)))
\ No newline at end of file
+    print("Removed %i/%i images" % (removed, len(test_set)))
\ No newline at end of file
diff --git a/codes/options/options.py b/codes/options/options.py
index 28afe856..297426f9 100644
--- a/codes/options/options.py
+++ b/codes/options/options.py
@@ -43,9 +43,12 @@ def parse(opt_path, is_train=True):
         dataset['mode'] = dataset['mode'].replace('_mc', '')
 
     # path
-    for key, path in opt['path'].items():
-        if path and key in opt['path'] and key != 'strict_load':
-            opt['path'][key] = osp.expanduser(path)
+    if 'path' in opt.keys():
+        for key, path in opt['path'].items():
+            if path and key in opt['path'] and key != 'strict_load':
+                opt['path'][key] = osp.expanduser(path)
+    else:
+        opt['path'] = {}
     opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
     if is_train:
         experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
diff --git a/codes/test.py b/codes/test.py
index bf1061ef..f025290b 100644
--- a/codes/test.py
+++ b/codes/test.py
@@ -86,7 +86,6 @@ def forward_pass(model, output_dir, alteration_suffix=''):
 if __name__ == "__main__":
     #### options
     torch.backends.cudnn.benchmark = True
-    want_just_images = True
     srg_analyze = False
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml')
diff --git a/codes/train.py b/codes/train.py
index f18cb224..53c0c4f2 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,9 +32,8 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_ssgr1.yml')
-    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
-                        help='job launcher')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_blur_discriminator.yml')
+    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -83,7 +82,7 @@ def main():
     if resume_state is None:
         util.mkdir_and_rename(
             opt['path']['experiments_root'])  # rename experiment folder if exists
-        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
+        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' and path is not None
                      and 'pretrain_model' not in key and 'resume' not in key))
         # config loggers. Before it, the log will not work
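
Usage note (not part of the patch): a minimal sketch of how a dataset options dict might exercise the new 'single_image_extensible' mode and the corruption options this diff touches. Only the keys visible in the diff ('mode', 'target_size', 'force_multiple', 'cache_ref', 'num_corrupts_per_image', 'corruption_blur_scale', 'fixed_corruptions', 'random_corruptions') are taken from the source; the 'paths' key, the example values, and the import path are assumptions.

# Hypothetical sketch only; key names and values not shown in the diff are assumptions.
dataset_opt = {
    'mode': 'single_image_extensible',   # routes create_dataset() to SingleImageDataset
    'paths': '/data/chunked_images',     # assumed key: wherever the dataset looks for chunk directories
    'target_size': 256,                  # read into self.target_hq_size
    'force_multiple': 16,                # read into self.multiple
    'cache_ref': False,                  # ChunkWithReference.reload() now defaults this to False
    'num_corrupts_per_image': 0,         # 0 makes ImageCorruptor a no-op after this change
    'corruption_blur_scale': 1,          # scales the gaussian/motion/smooth blur kernel sizes
    'fixed_corruptions': [],
    'random_corruptions': [],
}

from data import create_dataset  # assumes codes/ is on sys.path, as when running train.py/test.py
train_set = create_dataset(dataset_opt)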