From fb2cfc795bb98db28500100ea21e35cfb392a7bb Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 16 Dec 2020 09:42:50 -0700 Subject: [PATCH] Update requirements, add image_patch_classifier tool --- codes/requirements.txt | 3 +- codes/test_image_patch_classifier.py | 74 ++++++++++++++++++++++++++++ codes/train.py | 2 +- 3 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 codes/test_image_patch_classifier.py diff --git a/codes/requirements.txt b/codes/requirements.txt index fd6ec31b..6684409f 100644 --- a/codes/requirements.txt +++ b/codes/requirements.txt @@ -14,4 +14,5 @@ tensorboard pytorch_fid kornia linear_attention_transformer -vector_quantize_pytorch \ No newline at end of file +vector_quantize_pytorch +orjson \ No newline at end of file diff --git a/codes/test_image_patch_classifier.py b/codes/test_image_patch_classifier.py new file mode 100644 index 00000000..2ee39f90 --- /dev/null +++ b/codes/test_image_patch_classifier.py @@ -0,0 +1,74 @@ +import os.path as osp +import logging +import time +import argparse +from collections import OrderedDict + +import os + +import utils +import utils.options as option +import utils.util as util +from data.util import bgr2ycbcr +import models.archs.SwitchedResidualGenerator_arch as srg +from models.ExtensibleTrainer import ExtensibleTrainer +from switched_conv.switched_conv_util import save_attention_to_image, save_attention_to_image_rgb +from switched_conv.switched_conv import compute_attention_specificity +from data import create_dataset, create_dataloader +from tqdm import tqdm +import torch +import models.networks as networks +import torchvision + + +if __name__ == "__main__": + #### options + torch.backends.cudnn.benchmark = True + want_metrics = False + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/train_imgset_structural_classifier.yml') + opt = option.parse(parser.parse_args().opt, is_train=False) + opt = option.dict_to_nonedict(opt) + utils.util.loaded_options = opt + + util.mkdirs( + (path for key, path in opt['path'].items() + if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) + util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + + #### Create test dataset and dataloader + test_loaders = [] + for phase, dataset_opt in sorted(opt['datasets'].items()): + dataset_opt['dataset']['includes_labels'] = False + del dataset_opt['dataset']['labeler'] + test_set = create_dataset(dataset_opt) + if hasattr(test_set, 'wrapped_dataset'): + test_set = test_set.wrapped_dataset + test_loader = create_dataloader(test_set, dataset_opt, opt) + logger.info('Number of test images: {:d}'.format(len(test_set))) + test_loaders.append(test_loader) + + model = ExtensibleTrainer(opt) + gen = model.netsG['generator'] + label_to_search_for = 4 + + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + test_start_time = time.time() + dataset_dir = osp.join(opt['path']['results_root'], opt['name']) + util.mkdir(dataset_dir) + + tq = tqdm(test_loader) + step = 1 + for data in tq: + hq = data['hq'].to('cuda') + res = gen(hq) + res = torch.nn.functional.interpolate(res, size=hq.shape[2:], mode="nearest") + res_lbl = res[:, label_to_search_for, :, :].unsqueeze(1) + res_lbl_mask = (1.0 * (res_lbl > .5))*.5 + .5 + hq = hq * res_lbl_mask + torchvision.utils.save_image(hq, os.path.join(dataset_dir, "%i.png" % (step,))) + step += 1 diff --git a/codes/train.py b/codes/train.py index 2e0ee82f..53144678 100644 --- a/codes/train.py +++ b/codes/train.py @@ -293,7 +293,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgsetext_rrdb_bigboi_psnr_4x.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_structural_classifier.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()