diff --git a/codes/models/archs/mdcn/common.py b/codes/models/archs/mdcn/common.py
new file mode 100644
index 00000000..f72b82d8
--- /dev/null
+++ b/codes/models/archs/mdcn/common.py
@@ -0,0 +1,86 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.autograd import Variable
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    return nn.Conv2d(
+        in_channels, out_channels, kernel_size,
+        padding=(kernel_size//2), bias=bias)
+
+class MeanShift(nn.Conv2d):
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+        self.weight.data.div_(std.view(3, 1, 1, 1))
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+        self.bias.data.div_(std)
+        self.requires_grad = False
+
+class BasicBlock(nn.Sequential):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
+        bn=True, act=nn.ReLU(True)):
+
+        m = [nn.Conv2d(
+            in_channels, out_channels, kernel_size,
+            padding=(kernel_size//2), stride=stride, bias=bias)
+        ]
+        if bn: m.append(nn.BatchNorm2d(out_channels))
+        if act is not None: m.append(act)
+        super(BasicBlock, self).__init__(*m)
+
+class ResBlock(nn.Module):
+    def __init__(
+        self, conv, n_feats, kernel_size,
+        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
+
+        super(ResBlock, self).__init__()
+        m = []
+        for i in range(2):
+            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
+            if bn: m.append(nn.BatchNorm2d(n_feats))
+            if i == 0: m.append(act)
+
+        self.body = nn.Sequential(*m)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        res = self.body(x).mul(self.res_scale)
+        res += x
+
+        return res
+
+
+class Upsampler(nn.Sequential):
+    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=False):
+
+        m = []
+        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
+            for _ in range(int(math.log(scale, 2))):
+                m.append(conv(n_feats, 4 * n_feats, 3, bias))
+                m.append(nn.PixelShuffle(2))
+                if bn: m.append(nn.BatchNorm2d(n_feats))
+
+                if act == 'relu':
+                    m.append(nn.ReLU(True))
+                elif act == 'prelu':
+                    m.append(nn.PReLU(n_feats))
+
+        elif scale == 3:
+            m.append(conv(n_feats, 9 * n_feats, 3, bias))
+            m.append(nn.PixelShuffle(3))
+            if bn: m.append(nn.BatchNorm2d(n_feats))
+
+            if act == 'relu':
+                m.append(nn.ReLU(True))
+            elif act == 'prelu':
+                m.append(nn.PReLU(n_feats))
+        else:
+            raise NotImplementedError
+
+        super(Upsampler, self).__init__(*m)
diff --git a/codes/models/archs/mdcn/mdcn.py b/codes/models/archs/mdcn/mdcn.py
new file mode 100644
index 00000000..a2f7ad7a
--- /dev/null
+++ b/codes/models/archs/mdcn/mdcn.py
@@ -0,0 +1,152 @@
+from models.archs.mdcn import common
+import torch
+import torch.nn as nn
+
+from utils.util import checkpoint
+
+
+def make_model(args, parent=False):
+    return MDCN(args)
+
+
+class MDCB(nn.Module):
+    def __init__(self, conv=common.default_conv):
+        super(MDCB, self).__init__()
+
+        n_feats = 128
+        d_feats = 96
+        kernel_size_1 = 3
+        kernel_size_2 = 5
+        act = nn.ReLU(True)
+
+        self.conv_3_1 = conv(n_feats, n_feats, kernel_size_1)
+        self.conv_3_2 = conv(d_feats, d_feats, kernel_size_1)
+        self.conv_5_1 = conv(n_feats, n_feats, kernel_size_2)
+        self.conv_5_2 = conv(d_feats, d_feats, kernel_size_2)
+        self.confusion_3 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
+        self.confusion_5 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
+        self.confusion_bottle = nn.Conv2d(n_feats * 3 + d_feats * 2, n_feats, 1, padding=0, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        input_1 = x
+        output_3_1 = self.relu(self.conv_3_1(input_1))
+        output_5_1 = self.relu(self.conv_5_1(input_1))
+        input_2 = torch.cat([input_1, output_3_1, output_5_1], 1)
+        input_2_3 = self.confusion_3(input_2)
+        input_2_5 = self.confusion_5(input_2)
+
+        output_3_2 = self.relu(self.conv_3_2(input_2_3))
+        output_5_2 = self.relu(self.conv_5_2(input_2_5))
+        input_3 = torch.cat([input_1, output_3_1, output_5_1, output_3_2, output_5_2], 1)
+        output = self.confusion_bottle(input_3)
+        output += x
+        return output
+
+
+class CALayer(nn.Module):
+    def __init__(self, n_feats, reduction=16):
+        super(CALayer, self).__init__()
+        # global average pooling: feature --> point
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        # feature channel downscale and upscale --> channel weight
+        self.conv_du = nn.Sequential(
+            nn.Conv2d(n_feats, n_feats // reduction, 1, padding=0, bias=True),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(n_feats // reduction, n_feats, 1, padding=0, bias=True),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        y = self.avg_pool(x)
+        y = self.conv_du(y)
+        return x * y
+
+
+class DB(nn.Module):
+    def __init__(self, conv=common.default_conv):
+        super(DB, self).__init__()
+
+        n_feats = 128
+        d_feats = 96
+        n_blocks = 12
+
+        self.fushion_down = nn.Conv2d(n_feats * (n_blocks - 1), d_feats, 1, padding=0, bias=True)
+        self.channel_attention = CALayer(d_feats)
+        self.fushion_up = nn.Conv2d(d_feats, n_feats, 1, padding=0, bias=True)
+
+    def forward(self, x):
+        x = self.fushion_down(x)
+        x = self.channel_attention(x)
+        x = self.fushion_up(x)
+        return x
+
+
+class MDCN(nn.Module):
+    def __init__(self, args, conv=common.default_conv):
+        super(MDCN, self).__init__()
+        n_feats = 128
+        kernel_size = 3
+        self.scale_idx = 0
+        act = nn.ReLU(True)
+
+        n_blocks = 12
+        self.n_blocks = n_blocks
+
+        # RGB mean for DIV2K
+        rgb_mean = (0.4488, 0.4371, 0.4040)
+        rgb_std = (1.0, 1.0, 1.0)
+        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
+
+        # define head module
+        modules_head = [conv(args.n_colors, n_feats, kernel_size)]
+
+        # define body module
+        modules_body = nn.ModuleList()
+        for i in range(n_blocks):
+            modules_body.append(MDCB())
+
+        # define distillation module
+        modules_dist = nn.ModuleList()
+        modules_dist.append(DB())
+
+        modules_transform = [conv(n_feats, n_feats, kernel_size)]
+        self.upsample = nn.ModuleList([
+            common.Upsampler(
+                conv, s, n_feats, act=True
+            ) for s in args.scale
+        ])
+        modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
+
+        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
+
+        self.head = nn.Sequential(*modules_head)
+        self.body = nn.Sequential(*modules_body)
+        self.dist = nn.Sequential(*modules_dist)
+        self.transform = nn.Sequential(*modules_transform)
+        self.rebult = nn.Sequential(*modules_rebult)
+
+    def forward(self, x):
+        x = self.sub_mean(x)
+        x = checkpoint(self.head, x)
+        front = x
+
+        MDCB_out = []
+        for i in range(self.n_blocks):
+            x = checkpoint(self.body[i], x)
+            if i != (self.n_blocks - 1):
+                MDCB_out.append(x)
+
+        hierarchical = torch.cat(MDCB_out, 1)
+        hierarchical = checkpoint(self.dist, hierarchical)
+
+        mix = front + hierarchical + x
+
+        out = checkpoint(self.transform, mix)
+        out = self.upsample[self.scale_idx](out)
+        out = checkpoint(self.rebult, out)
+        out = self.add_mean(out)
+        return out
+
+    def set_scale(self, scale_idx):
+        self.scale_idx = scale_idx
\ No newline at end of file
diff --git a/codes/models/networks.py b/codes/models/networks.py
index c51a0336..3714ace8 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -173,6 +173,10 @@ def define_G(opt, opt_net, scale=None):
         netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                        nb=opt_net['nb'], scale=opt_net['scale'],
                        initial_conv_stride=opt_net['initial_stride'])
+    elif which_model == 'mdcn':
+        from models.archs.mdcn.mdcn import MDCN
+        args = munchify({'scale': opt_net['scale'], 'n_colors': 3, 'rgb_range': 1.0})
+        netG = MDCN(args)
     else:
         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
     return netG