diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 94f430ac..6fefa9f4 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -456,17 +456,23 @@ class ConjoinBlock(nn.Module):
 
 # Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
 class ReferenceJoinBlock(nn.Module):
-    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False, kernel_size=3, depth=3):
+    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False, kernel_size=3, depth=3, join=True):
         super(ReferenceJoinBlock, self).__init__()
         self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=kernel_size, depth=depth,
                                      scale_init=residual_weight_init_factor, norm=False,
                                      weight_init_factor=residual_weight_init_factor)
-        self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
+        if join:
+            self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
+        else:
+            self.join_conv = None
 
     def forward(self, x, ref):
         joined = torch.cat([x, ref], dim=1)
         branch = self.branch(joined)
-        return self.join_conv(x + branch), torch.std(branch)
+        if self.join_conv is not None:
+            return self.join_conv(x + branch), torch.std(branch)
+        else:
+            return x + branch, torch.std(branch)
 
 
 # Basic convolutional upsampling block that uses interpolate.
diff --git a/codes/models/archs/panet/attention.py b/codes/models/archs/panet/attention.py
new file mode 100644
index 00000000..4ce2c20a
--- /dev/null
+++ b/codes/models/archs/panet/attention.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision import utils as vutils
+import models.archs.panet.common as common
+from models.archs.panet.tools import extract_image_patches, \
+    reduce_mean, reduce_sum, same_padding
+from utils.util import checkpoint
+
+
+class PyramidAttention(nn.Module):
+    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10,
+                 average=True, conv=common.default_conv):
+        super(PyramidAttention, self).__init__()
+        self.ksize = ksize
+        self.stride = stride
+        self.res_scale = res_scale
+        self.softmax_scale = softmax_scale
+        self.scale = [1 - i / 10 for i in range(level)]
+        self.average = average
+        escape_NaN = torch.FloatTensor([1e-4])
+        self.register_buffer('escape_NaN', escape_NaN)
+        self.conv_match_L_base = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
+        self.conv_match = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
+        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
+
+    def forward(self, input):
+        res = input
+        # theta
+        match_base = self.conv_match_L_base(input)
+        shape_base = list(res.size())
+        input_groups = torch.split(match_base, 1, dim=0)
+        # patch size for matching
+        kernel = self.ksize
+        # raw_w is for reconstruction
+        raw_w = []
+        # w is for matching
+        w = []
+        # build feature pyramid
+        for i in range(len(self.scale)):
+            ref = input
+            if self.scale[i] != 1:
+                ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
+            # feature transformation function f
+            base = self.conv_assembly(ref)
+            shape_input = base.shape
+            # sampling
+            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
+                                            strides=[self.stride, self.stride],
+                                            rates=[1, 1],
+                                            padding='same')  # [N, C*k*k, L]
+            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
+            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
+            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
+            raw_w.append(raw_w_i_groups)
+
+            # feature transformation function g
+            ref_i = self.conv_match(ref)
+            shape_ref = ref_i.shape
+            # sampling
+            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
+                                        strides=[self.stride, self.stride],
+                                        rates=[1, 1],
+                                        padding='same')
+            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
+            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
+            w_i_groups = torch.split(w_i, 1, dim=0)
+            w.append(w_i_groups)
+
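+        # A rough reading of the matching step below (comment added for clarity,
+        # not from the original PANet code): for each batch element, `wi` stacks
+        # every k x k matching patch from all pyramid levels into one kernel bank
+        # of shape [L_total, C, k, k]. Each patch is scaled to unit L2 norm (with
+        # escape_NaN guarding the division), so F.conv2d against the base features
+        # yields a normalized cross-correlation map [1, L_total, H, W]. The
+        # per-pixel softmax over those L_total candidates, sharpened by
+        # softmax_scale, gives the attention weights used to paste the
+        # conv_assembly patches back via F.conv_transpose2d.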
+        y = []
+        for idx, xi in enumerate(input_groups):
+            # group in a filter
+            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
+            # normalize
+            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
+                                                     axis=[1, 2, 3],
+                                                     keepdim=True)),
+                               self.escape_NaN)
+            wi_normed = wi / max_wi
+            # matching
+            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
+            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L_total, H, W]
+            yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])
+            # softmax matching score
+            yi = F.softmax(yi * self.softmax_scale, dim=1)
+
+            if not self.average:
+                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()
+
+            # deconv for patch pasting
+            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
+            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
+            y.append(yi)
+
+        y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
+        return y
\ No newline at end of file
diff --git a/codes/models/archs/panet/common.py b/codes/models/archs/panet/common.py
new file mode 100644
index 00000000..b67e6a74
--- /dev/null
+++ b/codes/models/archs/panet/common.py
@@ -0,0 +1,87 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
+    return nn.Conv2d(
+        in_channels, out_channels, kernel_size,
+        padding=(kernel_size // 2), stride=stride, bias=bias)
+
+
+class MeanShift(nn.Conv2d):
+    def __init__(
+            self, rgb_range,
+            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
+
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
+        for p in self.parameters():
+            p.requires_grad = False
+
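+# Rough usage sketch for MeanShift (added for illustration; assumes the defaults
+# above): the layer is a frozen 1x1 conv computing x / std + sign * rgb_range *
+# rgb_mean / std per channel, so with rgb_std of all ones:
+#   sub_mean = MeanShift(255)           # x -> x - 255 * rgb_mean
+#   add_mean = MeanShift(255, sign=1)   # x -> x + 255 * rgb_mean
+# and add_mean(sub_mean(img)) recovers img up to float precision.
+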
+class BasicBlock(nn.Sequential):
+    def __init__(
+            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
+            bn=False, act=nn.PReLU()):
+
+        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
+        if bn:
+            m.append(nn.BatchNorm2d(out_channels))
+        if act is not None:
+            m.append(act)
+
+        super(BasicBlock, self).__init__(*m)
+
+
+class ResBlock(nn.Module):
+    def __init__(
+            self, conv, n_feats, kernel_size,
+            bias=True, bn=False, act=nn.PReLU(), res_scale=1):
+
+        super(ResBlock, self).__init__()
+        m = []
+        for i in range(2):
+            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats))
+            if i == 0:
+                m.append(act)
+
+        self.body = nn.Sequential(*m)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        res = self.body(x).mul(self.res_scale)
+        res += x
+
+        return res
+
+
+class Upsampler(nn.Sequential):
+    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
+
+        m = []
+        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
+            for _ in range(int(math.log(scale, 2))):
+                # bias must be passed by keyword: default_conv takes stride
+                # before bias, so a positional `bias` here would set the stride.
+                m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
+                m.append(nn.PixelShuffle(2))
+                if bn:
+                    m.append(nn.BatchNorm2d(n_feats))
+                if act == 'relu':
+                    m.append(nn.ReLU(True))
+                elif act == 'prelu':
+                    m.append(nn.PReLU(n_feats))
+
+        elif scale == 3:
+            m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
+            m.append(nn.PixelShuffle(3))
+            if bn:
+                m.append(nn.BatchNorm2d(n_feats))
+            if act == 'relu':
+                m.append(nn.ReLU(True))
+            elif act == 'prelu':
+                m.append(nn.PReLU(n_feats))
+        else:
+            raise NotImplementedError
+
+        super(Upsampler, self).__init__(*m)
diff --git a/codes/models/archs/panet/panet.py b/codes/models/archs/panet/panet.py
new file mode 100644
index 00000000..9e226f7e
--- /dev/null
+++ b/codes/models/archs/panet/panet.py
@@ -0,0 +1,91 @@
+from models.archs.panet import common
+from models.archs.panet import attention
+import torch.nn as nn
+from utils.util import checkpoint
+
+
+def make_model(args, parent=False):
+    return PANET(args)
+
+
+class PANET(nn.Module):
+    def __init__(self, args, conv=common.default_conv):
+        super(PANET, self).__init__()
+
+        n_resblocks = args.n_resblocks
+        n_feats = args.n_feats
+        kernel_size = 3
+        scale = args.scale[0]
+
+        rgb_mean = (0.4488, 0.4371, 0.4040)
+        rgb_std = (1.0, 1.0, 1.0)
+        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
+        self.msa = attention.PyramidAttention()
+
+        # define head module
+        m_head = [conv(args.n_colors, n_feats, kernel_size)]
+
+        # define body module: pyramid attention sits between two stacks of
+        # residual blocks. Note that ResBlock's signature is (conv, n_feats,
+        # kernel_size, bias, bn, act, res_scale), so the activation must be
+        # passed by keyword.
+        m_body = [
+            common.ResBlock(
+                conv, n_feats, kernel_size, act=nn.PReLU(), res_scale=args.res_scale
+            ) for _ in range(n_resblocks // 2)
+        ]
+        m_body.append(self.msa)
+        for _ in range(n_resblocks // 2):
+            m_body.append(common.ResBlock(conv, n_feats, kernel_size, act=nn.PReLU(), res_scale=args.res_scale))
+        m_body.append(conv(n_feats, n_feats, kernel_size))
+
+        # define tail module
+        m_tail = [
+            common.Upsampler(conv, scale, n_feats, act=False),
+            conv(n_feats, args.n_colors, kernel_size)
+        ]
+
+        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
+
+        self.head = nn.Sequential(*m_head)
+        self.body = nn.ModuleList(m_body)
+        self.tail = nn.Sequential(*m_tail)
+
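+    # Two caveats worth flagging (observations about the code as written, not
+    # behavior changes): PyramidAttention() is built with its default
+    # channel=64, so this body only runs cleanly when n_feats == 64; and
+    # `checkpoint` from utils.util is assumed to be a gradient-checkpointing
+    # wrapper in the spirit of torch.utils.checkpoint, which is why forward()
+    # routes every block except the attention module through it.
+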
+    def forward(self, x):
+        # x = self.sub_mean(x)
+        x = self.head(x)
+
+        res = x
+        for b in self.body:
+            if b == self.msa:
+                res = self.msa(res)
+            else:
+                res = checkpoint(b, res)
+        res += x
+
+        x = checkpoint(self.tail, res)
+        # x = self.add_mean(x)
+
+        return x,
+
+    def load_state_dict(self, state_dict, strict=True):
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            if name in own_state:
+                if isinstance(param, nn.Parameter):
+                    param = param.data
+                try:
+                    own_state[name].copy_(param)
+                except Exception:
+                    if name.find('tail') == -1:
+                        raise RuntimeError('While copying the parameter named {}, '
+                                           'whose dimensions in the model are {} and '
+                                           'whose dimensions in the checkpoint are {}.'
+                                           .format(name, own_state[name].size(), param.size()))
+            elif strict:
+                if name.find('tail') == -1:
+                    raise KeyError('unexpected key "{}" in state_dict'
+                                   .format(name))
diff --git a/codes/models/archs/panet/tools.py b/codes/models/archs/panet/tools.py
new file mode 100644
index 00000000..47e7f971
--- /dev/null
+++ b/codes/models/archs/panet/tools.py
@@ -0,0 +1,84 @@
+import os
+import torch
+import numpy as np
+from PIL import Image
+
+import torch.nn.functional as F
+
+
+def normalize(x):
+    return x.mul_(2).add_(-1)
+
+
+def same_padding(images, ksizes, strides, rates):
+    assert len(images.size()) == 4
+    batch_size, channel, rows, cols = images.size()
+    out_rows = (rows + strides[0] - 1) // strides[0]
+    out_cols = (cols + strides[1] - 1) // strides[1]
+    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
+    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
+    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
+    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
+    # Pad the input
+    padding_top = int(padding_rows / 2.)
+    padding_left = int(padding_cols / 2.)
+    padding_bottom = padding_rows - padding_top
+    padding_right = padding_cols - padding_left
+    paddings = (padding_left, padding_right, padding_top, padding_bottom)
+    images = torch.nn.ZeroPad2d(paddings)(images)
+    return images
+
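+# Worked example for same_padding (illustrative numbers, not from the original
+# file), using the arguments PyramidAttention passes: for a 1 x 64 x 30 x 30 map
+# with ksizes=[3, 3], strides=[1, 1], rates=[1, 1]:
+#   out_rows = ceil(30 / 1) = 30, effective_k_row = (3 - 1) * 1 + 1 = 3
+#   padding_rows = max(0, (30 - 1) * 1 + 3 - 30) = 2  ->  1 top, 1 bottom
+# so a subsequent stride-1 unfold produces 30 * 30 patch locations, matching the
+# input resolution, like TensorFlow's 'SAME' padding.
+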
+def extract_image_patches(images, ksizes, strides, rates, padding='same'):
+    """
+    Extract patches from images and put them in the C output dimension.
+    :param images: A 4-D tensor with shape [batch, channels, in_rows, in_cols]
+    :param ksizes: [ksize_rows, ksize_cols], the size of the sliding window for
+        each dimension of images
+    :param strides: [stride_rows, stride_cols]
+    :param rates: [dilation_rows, dilation_cols]
+    :param padding: 'same' or 'valid'
+    :return: A tensor of shape [N, C*k*k, L]
+    """
+    assert len(images.size()) == 4
+    assert padding in ['same', 'valid']
+    batch_size, channel, height, width = images.size()
+
+    if padding == 'same':
+        images = same_padding(images, ksizes, strides, rates)
+    elif padding == 'valid':
+        pass
+    else:
+        raise NotImplementedError('Unsupported padding type: {}. '
+                                  'Only "same" or "valid" are supported.'.format(padding))
+
+    unfold = torch.nn.Unfold(kernel_size=ksizes,
+                             dilation=rates,
+                             padding=0,
+                             stride=strides)
+    patches = unfold(images)
+    return patches  # [N, C*k*k, L], L is the total number of such blocks
+
+
+def reduce_mean(x, axis=None, keepdim=False):
+    if axis is None:
+        axis = range(len(x.shape))
+    for i in sorted(axis, reverse=True):
+        x = torch.mean(x, dim=i, keepdim=keepdim)
+    return x
+
+
+def reduce_std(x, axis=None, keepdim=False):
+    if axis is None:
+        axis = range(len(x.shape))
+    for i in sorted(axis, reverse=True):
+        x = torch.std(x, dim=i, keepdim=keepdim)
+    return x
+
+
+def reduce_sum(x, axis=None, keepdim=False):
+    if axis is None:
+        axis = range(len(x.shape))
+    for i in sorted(axis, reverse=True):
+        x = torch.sum(x, dim=i, keepdim=keepdim)
+    return x
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 6e34ac6b..d3e58fdc 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SPSR_arch as spsr
 import models.archs.StructuredSwitchedGenerator as ssg
 import models.archs.rcan as rcan
+import models.archs.panet.panet as panet
 from collections import OrderedDict
 import torchvision
 import functools
@@ -48,6 +49,12 @@ def define_G(opt, net_key='network_G', scale=None):
         opt_net['n_colors'] = 3
         args_obj = munchify(opt_net)
         netG = rcan.RCAN(args_obj)
+    elif which_model == 'panet':
+        # expected opt keys: n_resblocks, res_scale, scale, n_feats
+        opt_net['rgb_range'] = 255
+        opt_net['n_colors'] = 3
+        args_obj = munchify(opt_net)
+        netG = panet.PANET(args_obj)
     elif which_model == "ConfigurableSwitchedResidualGenerator2":
         netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                        switch_reductions=opt_net['switch_reductions'],