Add PANet arch

This commit is contained in:
James Betker 2020-10-12 10:20:55 -06:00
parent 7cbf4fa665
commit 3409d88a1c
6 changed files with 375 additions and 3 deletions

View File

@@ -456,17 +456,23 @@ class ConjoinBlock(nn.Module):

# Designed explicitly to join a mainline trunk with reference data. Implemented as a residual branch.
class ReferenceJoinBlock(nn.Module):
    def __init__(self, nf, residual_weight_init_factor=1, block=ConvGnLelu, final_norm=False, kernel_size=3, depth=3, join=True):
        super(ReferenceJoinBlock, self).__init__()
        self.branch = MultiConvBlock(nf * 2, nf + nf // 2, nf, kernel_size=kernel_size, depth=depth,
                                     scale_init=residual_weight_init_factor, norm=False,
                                     weight_init_factor=residual_weight_init_factor)
        if join:
            self.join_conv = block(nf, nf, kernel_size=kernel_size, norm=final_norm, bias=False, activation=True)
        else:
            self.join_conv = None

    def forward(self, x, ref):
        joined = torch.cat([x, ref], dim=1)
        branch = self.branch(joined)
        if self.join_conv is not None:
            return self.join_conv(x + branch), torch.std(branch)
        else:
            return x + branch, torch.std(branch)


# Basic convolutional upsampling block that uses interpolate.
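
A quick sketch of how the new join flag changes the ReferenceJoinBlock above (hypothetical tensor sizes; ReferenceJoinBlock, ConvGnLelu and MultiConvBlock are all defined in this file):

    import torch
    x = torch.randn(2, 64, 32, 32)    # mainline trunk features, nf=64
    ref = torch.randn(2, 64, 32, 32)  # reference features, same shape

    # join=True (default): the residual sum is passed through the final conv
    block = ReferenceJoinBlock(64)
    out, branch_std = block(x, ref)   # out: [2, 64, 32, 32]

    # join=False: the raw residual sum is returned and join_conv is skipped
    block_nj = ReferenceJoinBlock(64, join=False)
    out_nj, _ = block_nj(x, ref)      # same shape, no final conv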

View File

@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import utils as vutils

import models.archs.panet.common as common
from models.archs.panet.tools import extract_image_patches, \
    reduce_mean, reduce_sum, same_padding
from utils.util import checkpoint


class PyramidAttention(nn.Module):
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True,
                 conv=common.default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        self.scale = [1 - i / 10 for i in range(level)]
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_L_base = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
        res = input
        # theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        input_groups = torch.split(match_base, 1, dim=0)
        # patch size for matching
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        # build feature pyramid
        for i in range(len(self.scale)):
            ref = input
            if self.scale[i] != 1:
                ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
            # feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            # sampling
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                            strides=[self.stride, self.stride],
                                            rates=[1, 1],
                                            padding='same')  # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)

            # feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            # sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)

        y = []
        for idx, xi in enumerate(input_groups):
            # group in a filter
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
            # normalize
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi / max_wi
            # matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W], L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi * self.softmax_scale, dim=1)
            if not self.average:
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()
            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
            y.append(yi)
        y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
        return y
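
A minimal smoke test for the block above, assuming the repo's import paths resolve. With level=5 the pyramid downscales features to [1.0, 0.9, 0.8, 0.7, 0.6] of their size, and the block is shape-preserving:

    import torch
    from models.archs.panet.attention import PyramidAttention

    att = PyramidAttention(level=5, channel=64, reduction=2, ksize=3)
    feat = torch.randn(1, 64, 32, 32)   # [N, C, H, W]; C must match channel
    out = att(feat)                     # cross-scale attention plus residual
    assert out.shape == feat.shape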

View File

@@ -0,0 +1,87 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), stride=stride, bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False


class BasicBlock(nn.Sequential):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
            bn=False, act=nn.PReLU()):
        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.PReLU(), res_scale=1):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                # pass bias by keyword: positionally it would land in default_conv's stride slot
                m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))
        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
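
Upsampler only supports power-of-two factors (chained conv + PixelShuffle(2) stages) and 3; anything else raises. A short sketch using default_conv and Upsampler from this file:

    up4 = Upsampler(default_conv, 4, 64)   # two conv + PixelShuffle(2) stages
    up3 = Upsampler(default_conv, 3, 64)   # one conv + PixelShuffle(3) stage
    # Upsampler(default_conv, 5, 64) raises NotImplementedError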

View File

@@ -0,0 +1,91 @@
from models.archs.panet import common
from models.archs.panet import attention

import torch.nn as nn
from utils.util import checkpoint


def make_model(args, parent=False):
    return PANET(args)


class PANET(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(PANET, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        self.msa = attention.PyramidAttention()

        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module; pass the activation by keyword, since the
        # fourth positional slot of ResBlock is bias, not act
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=nn.PReLU(), res_scale=args.res_scale
            ) for _ in range(n_resblocks // 2)
        ]
        m_body.append(self.msa)
        for _ in range(n_resblocks // 2):
            m_body.append(common.ResBlock(conv, n_feats, kernel_size, act=nn.PReLU(), res_scale=args.res_scale))
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)
        ]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
        self.head = nn.Sequential(*m_head)
        self.body = nn.ModuleList(m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        # x = self.sub_mean(x)
        x = self.head(x)

        res = x
        for b in self.body:
            if b == self.msa:
                # the attention block is run directly, not checkpointed
                res = self.msa(res)
            else:
                res = checkpoint(b, res)
        res += x

        x = checkpoint(self.tail, res)
        # x = self.add_mean(x)
        return x,

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
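
A hedged construction sketch; any namespace exposing the fields PANET reads will do (the repo itself uses munchify), and utils.util.checkpoint is assumed to behave like torch.utils.checkpoint:

    from types import SimpleNamespace
    import torch

    # field names mirror what PANET.__init__ reads; values are illustrative
    args = SimpleNamespace(n_resblocks=16, n_feats=64, scale=[2],
                           res_scale=1, rgb_range=255, n_colors=3)
    net = PANET(args)
    sr, = net(torch.randn(1, 3, 32, 32))   # forward returns a 1-tuple
    # sr: [1, 3, 64, 64] for the x2 Upsampler tail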

View File

@@ -0,0 +1,84 @@
import os

import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F


def normalize(x):
    return x.mul_(2).add_(-1)


def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param images: A 4-D Tensor with shape [batch, channels, in_rows, in_cols]
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
      each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' or 'valid'
    :return: A Tensor of shape [N, C*k*k, L]
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks


def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x


def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
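
A brief shape walk-through for the helpers above (illustrative values):

    import torch

    x = torch.randn(1, 3, 32, 32)
    patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1],
                                    rates=[1, 1], padding='same')
    # 'same' keeps one patch per pixel: [N, C*k*k, L] = [1, 27, 1024]

    s = reduce_sum(torch.ones(2, 3, 4), axis=[1, 2], keepdim=True)
    # sums over dims 2 then 1 -> shape [2, 1, 1], each entry 12.0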

View File

@@ -12,6 +12,7 @@ import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.SPSR_arch as spsr
import models.archs.StructuredSwitchedGenerator as ssg
import models.archs.rcan as rcan
import models.archs.panet.panet as panet
from collections import OrderedDict
import torchvision
import functools

@@ -48,6 +49,12 @@ def define_G(opt, net_key='network_G', scale=None):
        opt_net['n_colors'] = 3
        args_obj = munchify(opt_net)
        netG = rcan.RCAN(args_obj)
    elif which_model == 'panet':
        # args: n_resblocks, res_scale, scale, n_feats
        opt_net['rgb_range'] = 255
        opt_net['n_colors'] = 3
        args_obj = munchify(opt_net)
        netG = panet.PANET(args_obj)
    elif which_model == "ConfigurableSwitchedResidualGenerator2":
        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                       switch_reductions=opt_net['switch_reductions'],
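
For reference, a sketch of the generator options the new panet branch consumes. The which_model_G key name is an assumption based on the surrounding define_G convention; rgb_range and n_colors are overridden inside the branch itself:

    opt_net = {
        'which_model_G': 'panet',   # assumed key; selects the branch above
        'n_resblocks': 16,
        'n_feats': 64,
        'res_scale': 1,
        'scale': [2],               # PANET reads args.scale[0]
    }
    # define_G forces rgb_range=255 and n_colors=3 before munchify(opt_net)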