From 3ab39f0d222f40953557450edf29ff8b545be901 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 5 Aug 2020 10:01:24 -0600 Subject: [PATCH] Several new spsr nets --- codes/models/archs/SPSR_arch.py | 217 +++++++++++++++++- .../archs/SwitchedResidualGenerator_arch.py | 7 +- codes/models/networks.py | 3 + codes/train.py | 2 +- sandbox.py | 22 ++ 5 files changed, 247 insertions(+), 4 deletions(-) create mode 100644 sandbox.py diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 1757d8c9..56eb944b 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -4,7 +4,10 @@ import torch.nn as nn import torch.nn.functional as F from models.archs import SPSR_util as B from .RRDBNet_arch import RRDB -from models.archs.arch_util import ConvGnLelu, ExpansionBlock, UpconvBlock +from models.archs.arch_util import ConvGnLelu, UpconvBlock +from models.archs.SwitchedResidualGenerator_arch import MultiConvBlock, ConvBasisMultiplexer, ConfigurableSwitchComputer +from switched_conv_util import save_attention_to_image_rgb +import functools class ImageGradient(nn.Module): @@ -250,6 +253,7 @@ class SPSRNetSimplified(nn.Module): self.b_proc_block_3 = RRDB(nf, gc=32) self.b_concat_decimate_4 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) self.b_proc_block_4 = RRDB(nf, gc=32) + # Upsampling self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) b_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) @@ -336,3 +340,214 @@ class SPSRNetSimplified(nn.Module): ######### return x_out_branch, x_out, x_grad + +class SPSRNetSimplifiedNoSkip(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, upscale=4): + super(SPSRNetSimplifiedNoSkip, self).__init__() + n_upscale = int(math.log(upscale, 2)) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.model_shortcut_blk = nn.Sequential(*[RRDB(nf, gc=32) for _ in range(nb)]) + self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.feature_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + + # Grad branch + self.get_g_nopadding = ImageGradientNoPadding() + self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) + self.b_concat_decimate_1 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) + self.b_proc_block_1 = RRDB(nf, gc=32) + self.b_concat_decimate_2 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) + self.b_proc_block_2 = RRDB(nf, gc=32) + self.b_concat_decimate_3 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) + self.b_proc_block_3 = RRDB(nf, gc=32) + self.b_concat_decimate_4 = ConvGnLelu(2 * nf, nf, kernel_size=1, norm=False, activation=False, bias=False) + self.b_proc_block_4 = RRDB(nf, gc=32) + # Upsampling + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + b_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + grad_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + grad_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + self.branch_upsample = B.sequential(*b_upsampler, grad_hr_conv1, grad_hr_conv2) + # Conv used to output grad branch shortcut. + self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) + + # Conjoin branch. + # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. + self._branch_pretrain_concat = ConvGnLelu(nf * 2, nf, kernel_size=1, norm=False, activation=False, bias=False) + self._branch_pretrain_block = RRDB(nf * 2, gc=32) + self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + + def forward(self, x): + + x_grad = self.get_g_nopadding(x) + x = self.model_fea_conv(x) + + x_ori = x + for i in range(5): + x = self.model_shortcut_blk[i](x) + x_fea1 = x + + for i in range(5): + x = self.model_shortcut_blk[i + 5](x) + x_fea2 = x + + for i in range(5): + x = self.model_shortcut_blk[i + 10](x) + x_fea3 = x + + for i in range(5): + x = self.model_shortcut_blk[i + 15](x) + x_fea4 = x + + x = self.model_shortcut_blk[20:](x) + x = self.feature_lr_conv(x) + + # short cut + x = x_ori + x + x = self.model_upsampler(x) + x = self.feature_hr_conv1(x) + x = self.feature_hr_conv2(x) + + x_b_fea = self.b_fea_conv(x_grad) + x_cat_1 = self.b_proc_block_1(x_b_fea) + x_cat_2 = self.b_proc_block_2(x_cat_1) + x_cat_3 = self.b_proc_block_3(x_cat_2) + x_cat_4 = self.b_proc_block_4(x_cat_3) + x_cat_4 = x_cat_4 + x_b_fea + x_cat_4 = self.grad_lr_conv(x_cat_4) + + # short cut + x_branch = self.branch_upsample(x_cat_4) + x_out_branch = self.grad_branch_output_conv(x_branch) + + ######## + x_branch_d = x_branch + x__branch_pretrain_cat = torch.cat([x_branch_d, x], dim=1) + x__branch_pretrain_cat = self._branch_pretrain_block(x__branch_pretrain_cat) + x_out = self._branch_pretrain_concat(x__branch_pretrain_cat) + x_out = self._branch_pretrain_HR_conv0(x_out) + x_out = self._branch_pretrain_HR_conv1(x_out) + + ######### + return x_out_branch, x_out, x_grad + + +class SwitchedSpsr(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, upscale=4): + super(SwitchedSpsr, self).__init__() + n_upscale = int(math.log(upscale, 2)) + + # switch options + transformation_filters = nf + switch_filters = nf + switch_reductions = 3 + switch_processing_layers = 2 + trans_counts = 8 + multiplx_fn = functools.partial(ConvBasisMultiplexer, transformation_filters, switch_filters, switch_reductions, + switch_processing_layers, trans_counts) + pretransform_fn = functools.partial(ConvGnLelu, transformation_filters, transformation_filters, norm=False, bias=False, weight_init_factor=.1) + transform_fn = functools.partial(MultiConvBlock, transformation_filters, int(transformation_filters * 1.5), + transformation_filters, kernel_size=3, depth=trans_layers, + weight_init_factor=.1) + + # Feature branch + self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=trans_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=trans_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + self.feature_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + self.model_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + self.feature_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self.feature_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + + # Grad branch + self.get_g_nopadding = ImageGradientNoPadding() + self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False, bias=False) + self.sw_grad = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=trans_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + # Upsampling + self.grad_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + b_upsampler = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=False, bias=False) for _ in range(n_upscale)]) + grad_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + grad_hr_conv2 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False, bias=False) + self.branch_upsample = B.sequential(*b_upsampler, grad_hr_conv1, grad_hr_conv2) + # Conv used to output grad branch shortcut. + self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=False) + + # Conjoin branch. + # Note: "_branch_pretrain" is a special tag used to denote parameters that get pretrained before the rest. + self._branch_pretrain_concat = ConvGnLelu(nf * 2, nf, kernel_size=1, norm=False, activation=False, bias=False) + self._branch_pretrain_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn, + pre_transform_block=pretransform_fn, transform_block=transform_fn, + attention_norm=True, + transform_count=trans_counts, init_temp=10, + add_scalable_noise_to_transforms=True) + self._branch_pretrain_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=False) + self._branch_pretrain_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.switches = [self.sw1, self.sw2, self.sw_grad, self._branch_pretrain_sw] + self.attentions = None + self.init_temperature = 10 + self.final_temperature_step = 10000 + + def forward(self, x): + x_grad = self.get_g_nopadding(x) + x = self.model_fea_conv(x) + + x1, a1 = self.sw1(x, True) + x2, a2 = self.sw2(x1, True) + x_fea = self.feature_lr_conv(x2) + x_fea = self.model_upsampler(x_fea) + x_fea = self.feature_hr_conv1(x_fea) + x_fea = self.feature_hr_conv2(x_fea) + + x_b_fea = self.b_fea_conv(x_grad) + x_grad, a3 = self.sw_grad(x_b_fea, att_in=x1, output_attention_weights=True) + x_grad = self.grad_lr_conv(x_grad) + x_grad = self.branch_upsample(x_grad) + x_out_branch = self.grad_branch_output_conv(x_grad) + + x__branch_pretrain_cat = torch.cat([x_grad, x_fea], dim=1) + x__branch_pretrain_cat, a4 = self._branch_pretrain_sw(x__branch_pretrain_cat, True) + x_out = self._branch_pretrain_concat(x__branch_pretrain_cat) + x_out = self._branch_pretrain_HR_conv0(x_out) + x_out = self._branch_pretrain_HR_conv1(x_out) + + return x_out_branch, x_out, x_grad + + def set_temperature(self, temp): + [sw.set_temperature(temp) for sw in self.switches] + + def update_for_step(self, step, experiments_path='.'): + if self.attentions: + temp = max(1, 1 + self.init_temperature * + (self.final_temperature_step - step) / self.final_temperature_step) + self.set_temperature(temp) + if step % 50 == 0: + output_path = os.path.join(experiments_path, "attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + + def get_debug_values(self, step): + temp = self.switches[0].switch.temperature + mean_hists = [compute_attention_specificity(att, 2) for att in self.attentions] + means = [i[0] for i in mean_hists] + hists = [i[1].clone().detach().cpu().flatten() for i in mean_hists] + val = {"switch_temperature": temp} + for i in range(len(means)): + val["switch_%i_specificity" % (i,)] = means[i] + val["switch_%i_histogram" % (i,)] = hists[i] + return val diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 2c3a6eb4..cc681a9f 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -134,7 +134,10 @@ class ConfigurableSwitchComputer(nn.Module): # depending on its needs. self.psc_scale = nn.Parameter(torch.full((1,), float(.1))) - def forward(self, x, output_attention_weights=False, fixed_scale=1): + def forward(self, x, output_attention_weights=False, att_in=None, fixed_scale=1): + if att_in is None: + att_in = x + identity = x if self.add_noise: rand_feature = torch.randn_like(x) * self.noise_scale @@ -143,7 +146,7 @@ class ConfigurableSwitchComputer(nn.Module): if self.pre_transform: x = self.pre_transform(x) xformed = [t.forward(x) for t in self.transforms] - m = self.multiplexer(identity) + m = self.multiplexer(att_in) outputs, attention = self.switch(xformed, m, True) diff --git a/codes/models/networks.py b/codes/models/networks.py index 5fccc92b..625314ab 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -108,6 +108,9 @@ def define_G(opt, net_key='network_G'): elif which_model == 'spsr_net_improved': netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) + elif which_model == 'spsr_net_improved_noskip': + netG = spsr.SPSRNetSimplifiedNoSkip(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], + nb=opt_net['nb'], upscale=opt_net['scale']) # image corruption elif which_model == 'HighToLowResNet': diff --git a/codes/train.py b/codes/train.py index 1de1fda4..e2df459e 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_rrdb.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_rrdb_noskip.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) diff --git a/sandbox.py b/sandbox.py new file mode 100644 index 00000000..18825c15 --- /dev/null +++ b/sandbox.py @@ -0,0 +1,22 @@ +import torch +import torchvision +from PIL import Image + +def load_img(path): + im = Image.open(path) + return torchvision.transforms.ToTensor()(im) + +def save_img(t, path): + torchvision.utils.save_image(t, path) + +img = load_img("me.png") +# add zeros to the imaginary component +img = torch.stack([img, torch.zeros_like(img)], dim=-1) +fft = torch.fft(img, signal_ndim=2) +fft_d = torch.zeros_like(fft) +for i in range(-5, 5): + diag = torch.diagonal(fft, offset=i, dim1=1, dim2=2) + diag_em = torch.diag_embed(diag, offset=i, dim1=1, dim2=2) + fft_d += diag_em +resamp_img = torch.ifft(fft_d, signal_ndim=2)[:, :, :, 0] +save_img(resamp_img, "resampled.png") \ No newline at end of file