This commit is contained in:
James Betker 2020-08-12 08:46:15 -06:00
parent 3d0ece804b
commit bdaa67deb7
3 changed files with 82 additions and 21 deletions

View File

@ -4,11 +4,13 @@ import time
import argparse import argparse
from collections import OrderedDict from collections import OrderedDict
import os
import options.options as option import options.options as option
import utils.util as util import utils.util as util
from data.util import bgr2ycbcr from data.util import bgr2ycbcr
import models.archs.SwitchedResidualGenerator_arch as srg import models.archs.SwitchedResidualGenerator_arch as srg
from switched_conv_util import save_attention_to_image from switched_conv_util import save_attention_to_image, save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
from models import create_model from models import create_model
from tqdm import tqdm from tqdm import tqdm
@ -22,14 +24,37 @@ import models.networks as networks
def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2): def alter_srg(srg: srg.ConfigurableSwitchedResidualGenerator2):
# First alteration, strip off switches one at a time. # First alteration, strip off switches one at a time.
yield "naked" yield "naked"
'''
for i in range(1, len(srg.switches)): for i in range(1, len(srg.switches)):
srg.switches = srg.switches[:-i] srg.switches = srg.switches[:-i]
yield "stripped-%i" % (i,) yield "stripped-%i" % (i,)
'''
for sw in srg.switches:
sw.set_temperature(.001)
yield "specific"
for sw in srg.switches:
sw.set_temperature(1000)
yield "normalized"
for sw in srg.switches:
sw.set_temperature(1)
sw.switch.attention_norm = None
yield "no_anorm"
return None return None
def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix): def analyze_srg(srg: srg.ConfigurableSwitchedResidualGenerator2, path, alteration_suffix):
[save_attention_to_image(path, srg.attentions[i], srg.transformation_counts, i, "attention_" + alteration_suffix, mean_hists = [compute_attention_specificity(att, 2) for att in srg.attentions]
l_mult=5) for i in range(len(srg.attentions))] means = [i[0] for i in mean_hists]
hists = [torch.histc(i[1].clone().detach().cpu().flatten().float(), bins=srg.transformation_counts) for i in mean_hists]
hists = [h / torch.sum(h) for h in hists]
for i in range(len(means)):
print("%s - switch_%i_specificity" % (alteration_suffix, i), means[i])
print("%s - switch_%i_histogram" % (alteration_suffix, i), hists[i])
[save_attention_to_image_rgb(path, srg.attentions[i], srg.transformation_counts, alteration_suffix, i) for i in range(len(srg.attentions))]
def forward_pass(model, output_dir, alteration_suffix=''): def forward_pass(model, output_dir, alteration_suffix=''):
@ -60,7 +85,7 @@ if __name__ == "__main__":
#### options #### options
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
want_just_images = True want_just_images = True
srg_analyze = True srg_analyze = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml') parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
@ -106,14 +131,16 @@ if __name__ == "__main__":
model_copy.load_state_dict(orig_model.state_dict()) model_copy.load_state_dict(orig_model.state_dict())
model.netG = model_copy model.netG = model_copy
for alteration_suffix in alter_srg(model_copy): for alteration_suffix in alter_srg(model_copy):
alt_path = osp.join(dataset_dir, alteration_suffix)
img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
img_name = osp.splitext(osp.basename(img_path))[0] img_name = osp.splitext(osp.basename(img_path))[0] + opt['name']
alteration_suffix += img_name alteration_suffix += img_name
os.makedirs(alt_path, exist_ok=True)
forward_pass(model, dataset_dir, alteration_suffix) forward_pass(model, dataset_dir, alteration_suffix)
analyze_srg(model_copy, dataset_dir, alteration_suffix) analyze_srg(model_copy, alt_path, alteration_suffix)
# Reset model and do next alteration. # Reset model and do next alteration.
model_copy = networks.define_G(opt).to(model.device) model_copy = networks.define_G(opt).to(model.device)
model_copy.load_state_dict(orig_model.state_dict()) model_copy.load_state_dict(orig_model.state_dict())
model.netG = model_copy model.netG = model_copy
else: else:
forward_pass(model, dataset_dir) forward_pass(model, dataset_dir, opt['name'])

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched_lr2.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
@ -161,7 +161,7 @@ def main():
current_step = resume_state['iter'] current_step = resume_state['iter']
model.resume_training(resume_state) # handle optimizers and schedulers model.resume_training(resume_state) # handle optimizers and schedulers
else: else:
current_step = -1 current_step = 0
start_epoch = 0 start_epoch = 0
#### training #### training

View File

@ -1,22 +1,56 @@
import torch import torch
import torchvision import torchvision
from PIL import Image from PIL import Image
from pytorch_wavelets import DWTForward, DWTInverse
import torch.nn.functional as F
def load_img(path): def load_img(path):
im = Image.open(path) im = Image.open(path).convert(mode="RGB")
return torchvision.transforms.ToTensor()(im) return torchvision.transforms.ToTensor()(im)
def save_img(t, path): def save_img(t, path):
torchvision.utils.save_image(t, path) torchvision.utils.save_image(t, path)
img = load_img("me.png") img = load_img("pu.jpg")
# add zeros to the imaginary component img = img.unsqueeze(0)
img = torch.stack([img, torch.zeros_like(img)], dim=-1)
fft = torch.fft(img, signal_ndim=2) # Reshape image to be multiple of 32
fft_d = torch.zeros_like(fft) w, h = img.shape[2:]
for i in range(-5, 5): w = (w // 32) * 32
diag = torch.diagonal(fft, offset=i, dim1=1, dim2=2) h = (h // 32) * 32
diag_em = torch.diag_embed(diag, offset=i, dim1=1, dim2=2) img = F.interpolate(img, size=(w, h))
fft_d += diag_em print("Input shape:", img.shape)
resamp_img = torch.ifft(fft_d, signal_ndim=2)[:, :, :, 0]
save_img(resamp_img, "resampled.png") J_spec = 5
Yl, Yh = DWTForward(J=J_spec, mode='periodization', wave='db3')(img)
print(Yl.shape, [h.shape for h in Yh])
imgLR = F.interpolate(img, scale_factor=.5)
LQYl, LQYh = DWTForward(J=J_spec-1, mode='periodization', wave='db3')(imgLR)
print(LQYl.shape, [h.shape for h in LQYh])
for i in range(J_spec):
smd = torch.sum(Yh[i], dim=2).cpu()
save_img(smd, "high_%i.png" % (i,))
save_img(Yl, "lo.png")
'''
Following code reconstructs the image with different high passes cancelled out.
'''
for i in range(J_spec):
corrupted_im = [y for y in Yh]
corrupted_im[i] = torch.zeros_like(corrupted_im[i])
im = DWTInverse(mode='periodization', wave='db3')((Yl, corrupted_im))
save_img(im, "corrupt_%i.png" % (i,))
im = DWTInverse(mode='periodization', wave='db3')((torch.full_like(Yl, fill_value=torch.mean(Yl)), Yh))
save_img(im, "corrupt_im.png")
'''
Following code reconstructs a hybrid image with the first high pass from the HR and the rest of the data from the LR.
highpass = [Yh[0]] + LQYh
im = DWTInverse(mode='periodization', wave='db3')((LQYl, highpass))
save_img(im, "hybrid_lrhr.png")
save_img(F.interpolate(imgLR, scale_factor=2), "upscaled.png")
'''