diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index 4bcdaa94..98aa3caf 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -41,6 +41,7 @@ class SRGANModel(BaseModel): p.requires_grad = True else: self.netC = None + self.mega_batch_factor = 1 # define losses, optimizer and scheduler if self.is_train: diff --git a/codes/test.py b/codes/test.py index 2e7f1337..63b37eff 100644 --- a/codes/test.py +++ b/codes/test.py @@ -7,17 +7,62 @@ from collections import OrderedDict import options.options as option import utils.util as util from data.util import bgr2ycbcr +import models.archs.SwitchedResidualGenerator_arch as srg +from switched_conv_util import save_attention_to_image from data import create_dataset, create_dataloader from models import create_model from tqdm import tqdm import torch +import models.networks as networks + + +# Concepts: Swap transformations around. Normalize attention. Disable individual switches, both randomly and one at +# a time, starting at the last switch. Pick random regions in an image and print out the full attention vector for +# each switch. Yield an output directory name for each alteration and None when last alteration is completed. +def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2): + # First alteration, strip off switches one at a time. + yield "naked" + for i in range(1, len(srg.switches)): + srg.switches = srg.switches[:-i] + yield "stripped-%i" % (i,) + return None + +def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix): + [save_attention_to_image(path, srg.attentions[i], srg.transformation_counts, i, "attention_" + alteration_suffix, + l_mult=5) for i in range(len(srg.attentions))] + + +def forward_pass(model, output_dir, alteration_suffix=''): + model.feed_data(data, need_GT=need_GT) + model.test() + + if isinstance(model.fake_GenOut[0], tuple): + visuals = model.fake_GenOut[0][0].detach().float().cpu() + else: + visuals = model.fake_GenOut[0].detach().float().cpu() + for i in range(visuals.shape[0]): + img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i] + img_name = osp.splitext(osp.basename(img_path))[0] + + sr_img = util.tensor2img(visuals[i]) # uint8 + + # save images + suffix = alteration_suffix + if suffix: + save_img_path = osp.join(output_dir, img_name + suffix + '.png') + else: + save_img_path = osp.join(output_dir, img_name + '.png') + + util.save_img(sr_img, save_img_path) + if __name__ == "__main__": #### options torch.backends.cudnn.benchmark = True want_just_images = True + srg_analyze = True parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/test_resgen_upsample.yml') + parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) @@ -54,40 +99,21 @@ if __name__ == "__main__": tq = tqdm(test_loader) for data in tq: need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True - model.feed_data(data, need_GT=need_GT) - model.test() - if isinstance(model.fake_H, tuple): - visuals = model.fake_H[0].detach().float().cpu() + if srg_analyze: + orig_model = model.netG + model_copy = networks.define_G(opt).to(model.device) + model_copy.load_state_dict(orig_model.state_dict()) + model.netG = model_copy + for alteration_suffix in alter_srg(model_copy): + img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] + img_name = osp.splitext(osp.basename(img_path))[0] + alteration_suffix += img_name + forward_pass(model, dataset_dir, alteration_suffix) + analyze_srg(model_copy, dataset_dir, alteration_suffix) + # Reset model and do next alteration. + model_copy = networks.define_G(opt).to(model.device) + model_copy.load_state_dict(orig_model.state_dict()) + model.netG = model_copy else: - visuals = model.fake_H.detach().float().cpu() - for i in range(visuals.shape[0]): - img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i] - img_name = osp.splitext(osp.basename(img_path))[0] - - sr_img = util.tensor2img(visuals[i]) # uint8 - - # save images - suffix = opt['suffix'] - if suffix: - save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') - else: - save_img_path = osp.join(dataset_dir, img_name + '.png') - util.save_img(sr_img, save_img_path) - - if want_just_images: - continue - - if not want_just_images and need_GT: # metrics - # Average PSNR/SSIM results - ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) - ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) - logger.info( - '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( - test_set_name, ave_psnr, ave_ssim)) - if test_results['psnr_y'] and test_results['ssim_y']: - ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) - ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) - logger.info( - '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'. - format(ave_psnr_y, ave_ssim_y)) + forward_pass(model, dataset_dir)