diff --git a/codes/test.py b/codes/test.py index 63b37eff..7f3fdd11 100644 --- a/codes/test.py +++ b/codes/test.py @@ -4,11 +4,13 @@ import time import argparse from collections import OrderedDict +import os import options.options as option import utils.util as util from data.util import bgr2ycbcr import models.archs.SwitchedResidualGenerator_arch as srg -from switched_conv_util import save_attention_to_image +from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb +from switched_conv import compute_attention_specificity from data import create_dataset, create_dataloader from models import create_model from tqdm import tqdm @@ -22,14 +24,37 @@ import models.networks as networks def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2): # First alteration, strip off switches one at a time. yield "naked" + + ''' for i in range(1, len(srg.switches)): srg.switches = srg.switches[:-i] yield "stripped-%i" % (i,) + ''' + + for sw in srg.switches: + sw.set_temperature(.001) + yield "specific" + + for sw in srg.switches: + sw.set_temperature(1000) + yield "normalized" + + for sw in srg.switches: + sw.set_temperature(1) + sw.switch.attention_norm = None + yield "no_anorm" return None def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix): - [save_attention_to_image(path, srg.attentions[i], srg.transformation_counts, i, "attention_" + alteration_suffix, - l_mult=5) for i in range(len(srg.attentions))] + mean_hists = [compute_attention_specificity(att, 2) for att in srg.attentions] + means = [i[0] for i in mean_hists] + hists = [torch.histc(i[1].clone().detach().cpu().flatten().float(), bins=srg.transformation_counts) for i in mean_hists] + hists = [h / torch.sum(h) for h in hists] + for i in range(len(means)): + print("%s - switch_%i_specificity" % (alteration_suffix, i), means[i]) + print("%s - switch_%i_histogram" % (alteration_suffix, i), hists[i]) + + [save_attention_to_image_rgb(path, srg.attentions[i], srg.transformation_counts, alteration_suffix, i) for i in range(len(srg.attentions))] def forward_pass(model, output_dir, alteration_suffix=''): @@ -60,7 +85,7 @@ if __name__ == "__main__": #### options torch.backends.cudnn.benchmark = True want_just_images = True - srg_analyze = True + srg_analyze = False parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml') opt = option.parse(parser.parse_args().opt, is_train=False) @@ -106,14 +131,16 @@ if __name__ == "__main__": model_copy.load_state_dict(orig_model.state_dict()) model.netG = model_copy for alteration_suffix in alter_srg(model_copy): + alt_path = osp.join(dataset_dir, alteration_suffix) img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] - img_name = osp.splitext(osp.basename(img_path))[0] + img_name = osp.splitext(osp.basename(img_path))[0] + opt['name'] alteration_suffix += img_name + os.makedirs(alt_path, exist_ok=True) forward_pass(model, dataset_dir, alteration_suffix) - analyze_srg(model_copy, dataset_dir, alteration_suffix) + analyze_srg(model_copy, alt_path, alteration_suffix) # Reset model and do next alteration. model_copy = networks.define_G(opt).to(model.device) model_copy.load_state_dict(orig_model.state_dict()) model.netG = model_copy else: - forward_pass(model, dataset_dir) + forward_pass(model, dataset_dir, opt['name']) diff --git a/codes/train.py b/codes/train.py index 3c8a0fb8..c6f28bef 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched_lr2.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) @@ -161,7 +161,7 @@ def main(): current_step = resume_state['iter'] model.resume_training(resume_state) # handle optimizers and schedulers else: - current_step = -1 + current_step = 0 start_epoch = 0 #### training diff --git a/sandbox.py b/sandbox.py index 18825c15..97d8232d 100644 --- a/sandbox.py +++ b/sandbox.py @@ -1,22 +1,56 @@ import torch import torchvision from PIL import Image +from pytorch_wavelets import DWTForward, DWTInverse +import torch.nn.functional as F def load_img(path): - im = Image.open(path) + im = Image.open(path).convert(mode="RGB") return torchvision.transforms.ToTensor()(im) def save_img(t, path): torchvision.utils.save_image(t, path) -img = load_img("me.png") -# add zeros to the imaginary component -img = torch.stack([img, torch.zeros_like(img)], dim=-1) -fft = torch.fft(img, signal_ndim=2) -fft_d = torch.zeros_like(fft) -for i in range(-5, 5): - diag = torch.diagonal(fft, offset=i, dim1=1, dim2=2) - diag_em = torch.diag_embed(diag, offset=i, dim1=1, dim2=2) - fft_d += diag_em -resamp_img = torch.ifft(fft_d, signal_ndim=2)[:, :, :, 0] -save_img(resamp_img, "resampled.png") \ No newline at end of file +img = load_img("pu.jpg") +img = img.unsqueeze(0) + +# Reshape image to be multiple of 32 +w, h = img.shape[2:] +w = (w // 32) * 32 +h = (h // 32) * 32 +img = F.interpolate(img, size=(w, h)) +print("Input shape:", img.shape) + +J_spec = 5 + +Yl, Yh = DWTForward(J=J_spec, mode='periodization', wave='db3')(img) +print(Yl.shape, [h.shape for h in Yh]) + +imgLR = F.interpolate(img, scale_factor=.5) +LQYl, LQYh = DWTForward(J=J_spec-1, mode='periodization', wave='db3')(imgLR) +print(LQYl.shape, [h.shape for h in LQYh]) + +for i in range(J_spec): + smd = torch.sum(Yh[i], dim=2).cpu() + save_img(smd, "high_%i.png" % (i,)) +save_img(Yl, "lo.png") + +''' +Following code reconstructs the image with different high passes cancelled out. +''' +for i in range(J_spec): + corrupted_im = [y for y in Yh] + corrupted_im[i] = torch.zeros_like(corrupted_im[i]) + im = DWTInverse(mode='periodization', wave='db3')((Yl, corrupted_im)) + save_img(im, "corrupt_%i.png" % (i,)) +im = DWTInverse(mode='periodization', wave='db3')((torch.full_like(Yl, fill_value=torch.mean(Yl)), Yh)) +save_img(im, "corrupt_im.png") + + +''' +Following code reconstructs a hybrid image with the first high pass from the HR and the rest of the data from the LR. +highpass = [Yh[0]] + LQYh +im = DWTInverse(mode='periodization', wave='db3')((LQYl, highpass)) +save_img(im, "hybrid_lrhr.png") +save_img(F.interpolate(imgLR, scale_factor=2), "upscaled.png") +''' \ No newline at end of file