From 9963b3720033ec7fc5909a23c2c530c1b4b1e605 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 17 Sep 2020 13:30:32 -0600 Subject: [PATCH] Add a new script for loading a discriminator network and using it to filter images --- .../use_discriminator_as_filter.py | 75 +++++++++++++++++++ codes/models/steps/injectors.py | 26 +++++++ 2 files changed, 101 insertions(+) create mode 100644 codes/data_scripts/use_discriminator_as_filter.py diff --git a/codes/data_scripts/use_discriminator_as_filter.py b/codes/data_scripts/use_discriminator_as_filter.py new file mode 100644 index 00000000..6615dcdf --- /dev/null +++ b/codes/data_scripts/use_discriminator_as_filter.py @@ -0,0 +1,75 @@ +import os.path as osp +import logging +import time +import argparse +from collections import OrderedDict + +import os +import options.options as option +import utils.util as util +from data.util import bgr2ycbcr +import models.archs.SwitchedResidualGenerator_arch as srg +from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb +from switched_conv import compute_attention_specificity +from data import create_dataset, create_dataloader +from models import create_model +from tqdm import tqdm +import torch +import models.networks as networks +import shutil +import torchvision + + +if __name__ == "__main__": + bin_path = "f:\\binned" + good_path = "f:\\good" + os.makedirs(bin_path, exist_ok=True) + os.makedirs(good_path, exist_ok=True) + + + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../../options/discriminator_filter.yml') + opt = option.parse(parser.parse_args().opt, is_train=False) + opt = option.dict_to_nonedict(opt) + opt['dist'] = False + + util.mkdirs( + (path for key, path in opt['path'].items() + if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) + util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, + screen=True, tofile=True) + logger = logging.getLogger('base') + logger.info(option.dict2str(opt)) + + #### Create test dataset and dataloader + test_loaders = [] + for phase, dataset_opt in sorted(opt['datasets'].items()): + test_set = create_dataset(dataset_opt) + test_loader = create_dataloader(test_set, dataset_opt, opt=opt) + logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) + test_loaders.append(test_loader) + + model = create_model(opt) + fea_loss = 0 + for test_loader in test_loaders: + test_set_name = test_loader.dataset.opt['name'] + logger.info('\nTesting [{:s}]...'.format(test_set_name)) + test_start_time = time.time() + dataset_dir = osp.join(opt['path']['results_root'], test_set_name) + util.mkdir(dataset_dir) + + tq = tqdm(test_loader) + for data in tq: + model.feed_data(data, need_GT=True) + model.test() + results = model.eval_state['discriminator_out'][0] + for i in range(results.shape[0]): + imname = osp.basename(data['GT_path'][i]) + if results[i] < 1: + torchvision.utils.save_image(data['GT'][i], osp.join(bin_path, imname)) + else: + torchvision.utils.save_image(data['GT'][i], osp.join(good_path, imname)) + + # log + logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader))) \ No newline at end of file diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 897f854a..ab1101a1 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -7,6 +7,8 @@ def create_injector(opt_inject, env): type = opt_inject['type'] if type == 'generator': return ImageGeneratorInjector(opt_inject, env) + elif type == 'discriminator': + return DiscriminatorInjector(opt_inject, env) elif type == 'scheduled_scalar': return ScheduledScalarInjector(opt_inject, env) elif type == 'img_grad': @@ -60,6 +62,30 @@ class ImageGeneratorInjector(Injector): return new_state +# Injects a result from a discriminator network into the state. +class DiscriminatorInjector(Injector): + def __init__(self, opt, env): + super(DiscriminatorInjector, self).__init__(opt, env) + + def forward(self, state): + d = self.env['discriminators'][self.opt['discriminator']] + if isinstance(self.input, list): + params = [state[i] for i in self.input] + results = d(*params) + else: + results = d(state[self.input]) + new_state = {} + if isinstance(self.output, list): + # Only dereference tuples or lists, not tensors. + assert isinstance(results, list) or isinstance(results, tuple) + for i, k in enumerate(self.output): + new_state[k] = results[i] + else: + new_state[self.output] = results + + return new_state + + # Creates an image gradient from [in] and injects it into [out] class ImageGradientInjector(Injector): def __init__(self, opt, env):