diff --git a/codes/models/archs/rcan.py b/codes/models/archs/rcan.py
new file mode 100644
index 00000000..71d6955c
--- /dev/null
+++ b/codes/models/archs/rcan.py
@@ -0,0 +1,217 @@
+import math
+
+import torch
+import torch.nn as nn
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2d(
+        in_channels, out_channels, kernel_size,
+        padding=(kernel_size//2), bias=bias)
+
+class MeanShift(nn.Conv2d):
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+        self.weight.data.div_(std.view(3, 1, 1, 1))
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+        self.bias.data.div_(std)
+        # Freeze: this layer is fixed pre/post-processing, not trained.
+        for p in self.parameters():
+            p.requires_grad = False
+
+class BasicBlock(nn.Sequential):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
+        bn=True, act=nn.ReLU(True)):
+
+        m = [nn.Conv2d(
+            in_channels, out_channels, kernel_size,
+            padding=(kernel_size//2), stride=stride, bias=bias)
+        ]
+        if bn: m.append(nn.BatchNorm2d(out_channels))
+        if act is not None: m.append(act)
+        super(BasicBlock, self).__init__(*m)
+
+class ResBlock(nn.Module):
+    def __init__(
+        self, conv, n_feat, kernel_size,
+        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
+
+        super(ResBlock, self).__init__()
+        m = []
+        for i in range(2):
+            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
+            if bn: m.append(nn.BatchNorm2d(n_feat))
+            if i == 0: m.append(act)
+
+        self.body = nn.Sequential(*m)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        res = self.body(x).mul(self.res_scale)
+        res += x
+
+        return res
+
+class Upsampler(nn.Sequential):
+    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
+
+        m = []
+        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
+            for _ in range(int(math.log(scale, 2))):
+                m.append(conv(n_feat, 4 * n_feat, 3, bias))
+                m.append(nn.PixelShuffle(2))
+                if bn: m.append(nn.BatchNorm2d(n_feat))
+                if act: m.append(act())
+        elif scale == 3:
+            m.append(conv(n_feat, 9 * n_feat, 3, bias))
+            m.append(nn.PixelShuffle(3))
+            if bn: m.append(nn.BatchNorm2d(n_feat))
+            if act: m.append(act())
+        else:
+            raise NotImplementedError
+
+        super(Upsampler, self).__init__(*m)
+
+def make_model(args, parent=False):
+    return RCAN(args)
+
+
+## Channel Attention (CA) Layer
+class CALayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(CALayer, self).__init__()
+        # global average pooling: feature --> point
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        # feature channel downscale and upscale --> channel weight
+        self.conv_du = nn.Sequential(
+            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        y = self.avg_pool(x)
+        y = self.conv_du(y)
+        return x * y
+
+
+## Residual Channel Attention Block (RCAB)
+class RCAB(nn.Module):
+    def __init__(
+        self, conv, n_feat, kernel_size, reduction,
+        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
+
+        super(RCAB, self).__init__()
+        modules_body = []
+        for i in range(2):
+            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
+            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
+            if i == 0: modules_body.append(act)
+        modules_body.append(CALayer(n_feat, reduction))
+        self.body = nn.Sequential(*modules_body)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        res = self.body(x)
+        # res = self.body(x).mul(self.res_scale)
+        res += x
+        return res
+
+
+## Residual Group (RG)
+class ResidualGroup(nn.Module):
+    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
+        super(ResidualGroup, self).__init__()
+        modules_body = [
+            RCAB(
+                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
+            for _ in range(n_resblocks)]
+        modules_body.append(conv(n_feat, n_feat, kernel_size))
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x):
+        res = self.body(x)
+        res += x
+        return res
+
+
+## Residual Channel Attention Network (RCAN)
+class RCAN(nn.Module):
+    def __init__(self, args, conv=default_conv):
+        super(RCAN, self).__init__()
+
+        n_resgroups = args.n_resgroups
+        n_resblocks = args.n_resblocks
+        n_feats = args.n_feats
+        kernel_size = 3
+        reduction = args.reduction
+        scale = args.scale
+        act = nn.ReLU(True)
+
+        # RGB mean for DIV2K
+        rgb_mean = (0.4488, 0.4371, 0.4040)
+        rgb_std = (1.0, 1.0, 1.0)
+        self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)
+
+        # define head module
+        modules_head = [conv(args.n_colors, n_feats, kernel_size)]
+
+        # define body module
+        modules_body = [
+            ResidualGroup(
+                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks)
+            for _ in range(n_resgroups)]
+
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
+
+        # define tail module
+        modules_tail = [
+            Upsampler(conv, scale, n_feats, act=False),
+            conv(n_feats, args.n_colors, kernel_size)]
+
+        self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
+
+        self.head = nn.Sequential(*modules_head)
+        self.body = nn.Sequential(*modules_body)
+        self.tail = nn.Sequential(*modules_tail)
+
+    def forward(self, x):
+        x = self.sub_mean(x)
+        x = self.head(x)
+
+        res = self.body(x)
+        res += x
+
+        x = self.tail(res)
+        x = self.add_mean(x)
+
+        return x
+
+    def load_state_dict(self, state_dict, strict=False):
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            if name in own_state:
+                if isinstance(param, nn.Parameter):
+                    param = param.data
+                try:
+                    own_state[name].copy_(param)
+                except Exception:
+                    if name.find('tail') >= 0:
+                        print('Replacing pre-trained upsampler with a new one...')
+                    else:
+                        raise RuntimeError('While copying the parameter named {}, '
+                                           'whose dimensions in the model are {} and '
+                                           'whose dimensions in the checkpoint are {}.'
+                                           .format(name, own_state[name].size(), param.size()))
+            elif strict:
+                if name.find('tail') == -1:
+                    raise KeyError('unexpected key "{}" in state_dict'
+                                   .format(name))
+
+        if strict:
+            missing = set(own_state.keys()) - set(state_dict.keys())
+            if len(missing) > 0:
+                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 858ed938..f06c87c0 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -10,6 +10,8 @@ import models.archs.feature_arch as feature_arch
 import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
 import models.archs.SPSR_arch as spsr
 import models.archs.StructuredSwitchedGenerator as ssg
+import models.archs.rcan as rcan
+from munch import munchify
 from collections import OrderedDict
 
 logger = logging.getLogger('base')
@@ -37,6 +39,12 @@
         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'] if 'scale' in opt_net.keys() else gen_scale,
                                     initial_stride=initial_stride)
+    elif which_model == 'rcan':
+        # expected opt_net args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
+        opt_net['rgb_range'] = 255
+        opt_net['n_colors'] = 3
+        args_obj = munchify(opt_net)
+        netG = rcan.RCAN(args_obj)
     elif which_model == "ConfigurableSwitchedResidualGenerator2":
         netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
                                                                        switch_reductions=opt_net['switch_reductions'],
diff --git a/codes/utils/numeric_stability.py b/codes/utils/numeric_stability.py
index 588555b6..2fc0fb5f 100644
--- a/codes/utils/numeric_stability.py
+++ b/codes/utils/numeric_stability.py
@@ -1,8 +1,6 @@
 import torch
 from torch import nn
-import models.archs.SRG1_arch as srg1
 import models.archs.SwitchedResidualGenerator_arch as srg
-import models.archs.NestedSwitchGenerator as nsg
 import models.archs.discriminator_vgg_arch as disc
 import functools
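
A quick way to exercise the new branch is to build the network the same way `define_G` does, from a munchified options dict. A minimal sketch, assuming the RCAN paper's default sizes for the keys `define_G` expects (run from `codes/` so the `models.archs` import resolves):

    import torch
    from munch import munchify

    import models.archs.rcan as rcan

    opt_net = {
        'n_resgroups': 10, 'n_resblocks': 20, 'n_feats': 64,  # assumed paper defaults, not repo configs
        'reduction': 16, 'scale': 4, 'res_scale': 1,
        'rgb_range': 255, 'n_colors': 3,                      # hard-coded by define_G above
    }
    netG = rcan.RCAN(munchify(opt_net))

    with torch.no_grad():
        out = netG(torch.rand(1, 3, 32, 32) * 255)
    assert out.shape == (1, 3, 128, 128)  # scale=4 -> two PixelShuffle(2) stages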
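
The `load_state_dict` override deliberately swallows shape mismatches on `tail` parameters, so a trunk pre-trained at one scale can seed a model at another scale. Continuing the sketch above (an assumed workflow, not one this repo prescribes):

    # Seed a 3x model from the 4x model's weights. The first tail conv
    # differs in shape (64->576 vs. 64->256 output channels), so copy_()
    # fails, the exception is caught, and only the notice is printed;
    # head, body and mean-shift weights copy over unchanged.
    netG3 = rcan.RCAN(munchify(dict(opt_net, scale=3)))
    netG3.load_state_dict(netG.state_dict(), strict=False)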