forked from mrq/DL-Art-School

Add RCAN

parent 4d29b7729e
commit aeaf185314

codes/models/archs/rcan.py (new file, 221 lines)

@@ -0,0 +1,221 @@
import math

import torch
import torch.nn as nn


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    # 1x1 conv that subtracts (sign=-1) or adds back (sign=1) the dataset RGB mean.
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Freeze the shift: it is a fixed normalization, not a learned layer.
        for p in self.parameters():
            p.requires_grad = False


class BasicBlock(nn.Sequential):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


def make_model(args, parent=False):
    return RCAN(args)


## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature map --> one value per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> per-channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        # res = self.body(x).mul(self.res_scale)
        res += x
        return res


## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    def __init__(self, args, conv=default_conv):
        super(RCAN, self).__init__()

        n_resgroups = args.n_resgroups
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        reduction = args.reduction
        scale = args.scale
        act = nn.ReLU(True)

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        modules_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
            for _ in range(n_resgroups)]

        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)]

        self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
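A minimal sketch of how the new arch can be exercised stand-alone, assuming the repo layout added by this commit. The attribute names on the args object are the ones the RCAN constructor above reads; the concrete values are illustrative defaults, not the settings used for training.

# Hypothetical smoke test for models/archs/rcan.py; values are illustrative only.
from types import SimpleNamespace

import torch

from models.archs.rcan import RCAN  # assumes the file added by this commit is on the path

args = SimpleNamespace(
    n_resgroups=10,   # residual groups in the trunk
    n_resblocks=20,   # RCABs per residual group
    n_feats=64,       # feature channels
    reduction=16,     # channel-attention squeeze ratio
    scale=2,          # upsampling factor: a power of two, or 3
    res_scale=1,      # residual scaling passed to each ResidualGroup
    rgb_range=255,    # matches the value define_G hard-codes below
    n_colors=3,       # RGB input/output
)

model = RCAN(args)
lr = torch.randn(1, 3, 32, 32)   # fake low-resolution batch
with torch.no_grad():
    sr = model(lr)
print(sr.shape)                  # torch.Size([1, 3, 64, 64]) for scale=2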
@@ -10,6 +10,7 @@ import models.archs.feature_arch as feature_arch
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SPSR_arch as spsr
 import models.archs.StructuredSwitchedGenerator as ssg
+import models.archs.rcan as rcan
 from collections import OrderedDict
 
 logger = logging.getLogger('base')
@@ -37,6 +38,12 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'] if 'scale' in opt_net.keys() else gen_scale,
                                     initial_stride=initial_stride)
+    elif which_model == 'rcan':
+        # args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
+        opt_net['rgb_range'] = 255
+        opt_net['n_colors'] = 3
+        args_obj = munchify(opt_net)
+        netG = rcan.RCAN(args_obj)
     elif which_model == "ConfigurableSwitchedResidualGenerator2":
         netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                        switch_reductions=opt_net['switch_reductions'],
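For context, a hedged sketch of the option block this branch consumes. The key names mirror the comment in the branch above; the 'which_model_G' key and the concrete values are assumptions for illustration, since the real training configs are not part of this diff. munchify() from the munch package converts the plain dict into the attribute-style args object that RCAN expects.

# Hypothetical network_G options for the 'rcan' branch; values are illustrative only.
from munch import munchify

opt_net = {
    'which_model_G': 'rcan',   # assumed key used to select this branch
    'n_resgroups': 10,
    'n_resblocks': 20,
    'n_feats': 64,
    'reduction': 16,
    'scale': 2,
    'res_scale': 1,
    # rgb_range and n_colors are overwritten to 255 / 3 inside define_G,
    # so they do not need to be set in the config.
}

args_obj = munchify(opt_net)   # dict -> attribute access (args_obj.n_feats == 64)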
@@ -1,8 +1,6 @@
 import torch
 from torch import nn
-import models.archs.SRG1_arch as srg1
 import models.archs.SwitchedResidualGenerator_arch as srg
-import models.archs.NestedSwitchGenerator as nsg
 import models.archs.discriminator_vgg_arch as disc
 import functools
 