Add mdcn
This commit is contained in:
parent
1e0f69e34b
commit
a1c8300052
86
codes/models/archs/mdcn/common.py
Normal file
86
codes/models/archs/mdcn/common.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
||||||
|
return nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size,
|
||||||
|
padding=(kernel_size//2), bias=bias)
|
||||||
|
|
||||||
|
class MeanShift(nn.Conv2d):
|
||||||
|
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
|
||||||
|
super(MeanShift, self).__init__(3, 3, kernel_size=1)
|
||||||
|
std = torch.Tensor(rgb_std)
|
||||||
|
self.weight.data = torch.eye(3).view(3, 3, 1, 1)
|
||||||
|
self.weight.data.div_(std.view(3, 1, 1, 1))
|
||||||
|
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
|
||||||
|
self.bias.data.div_(std)
|
||||||
|
self.requires_grad = False
|
||||||
|
|
||||||
|
class BasicBlock(nn.Sequential):
|
||||||
|
def __init__(
|
||||||
|
self, in_channels, out_channels, kernel_size, stride=1, bias=False,
|
||||||
|
bn=True, act=nn.ReLU(True)):
|
||||||
|
|
||||||
|
m = [nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size,
|
||||||
|
padding=(kernel_size//2), stride=stride, bias=bias)
|
||||||
|
]
|
||||||
|
if bn: m.append(nn.BatchNorm2d(out_channels))
|
||||||
|
if act is not None: m.append(act)
|
||||||
|
super(BasicBlock, self).__init__(*m)
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, conv, n_feats, kernel_size,
|
||||||
|
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
|
||||||
|
|
||||||
|
super(ResBlock, self).__init__()
|
||||||
|
m = []
|
||||||
|
for i in range(2):
|
||||||
|
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
|
||||||
|
if bn: m.append(nn.BatchNorm2d(n_feats))
|
||||||
|
if i == 0: m.append(act)
|
||||||
|
|
||||||
|
self.body = nn.Sequential(*m)
|
||||||
|
self.res_scale = res_scale
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = self.body(x).mul(self.res_scale)
|
||||||
|
res += x
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class Upsampler(nn.Sequential):
|
||||||
|
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=False):
|
||||||
|
|
||||||
|
m = []
|
||||||
|
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
|
||||||
|
for _ in range(int(math.log(scale, 2))):
|
||||||
|
m.append(conv(n_feats, 4 * n_feats, 3, bias))
|
||||||
|
m.append(nn.PixelShuffle(2))
|
||||||
|
if bn: m.append(nn.BatchNorm2d(n_feats))
|
||||||
|
|
||||||
|
if act == 'relu':
|
||||||
|
m.append(nn.ReLU(True))
|
||||||
|
elif act == 'prelu':
|
||||||
|
m.append(nn.PReLU(n_feats))
|
||||||
|
|
||||||
|
elif scale == 3:
|
||||||
|
m.append(conv(n_feats, 9 * n_feats, 3, bias))
|
||||||
|
m.append(nn.PixelShuffle(3))
|
||||||
|
if bn: m.append(nn.BatchNorm2d(n_feats))
|
||||||
|
|
||||||
|
if act == 'relu':
|
||||||
|
m.append(nn.ReLU(True))
|
||||||
|
elif act == 'prelu':
|
||||||
|
m.append(nn.PReLU(n_feats))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
super(Upsampler, self).__init__(*m)
|
152
codes/models/archs/mdcn/mdcn.py
Normal file
152
codes/models/archs/mdcn/mdcn.py
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
from models.archs.mdcn import common
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from utils.util import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def make_model(args, parent=False):
|
||||||
|
return MDCN(args)
|
||||||
|
|
||||||
|
|
||||||
|
class MDCB(nn.Module):
|
||||||
|
def __init__(self, conv=common.default_conv):
|
||||||
|
super(MDCB, self).__init__()
|
||||||
|
|
||||||
|
n_feats = 128
|
||||||
|
d_feats = 96
|
||||||
|
kernel_size_1 = 3
|
||||||
|
kernel_size_2 = 5
|
||||||
|
act = nn.ReLU(True)
|
||||||
|
|
||||||
|
self.conv_3_1 = conv(n_feats, n_feats, kernel_size_1)
|
||||||
|
self.conv_3_2 = conv(d_feats, d_feats, kernel_size_1)
|
||||||
|
self.conv_5_1 = conv(n_feats, n_feats, kernel_size_2)
|
||||||
|
self.conv_5_2 = conv(d_feats, d_feats, kernel_size_2)
|
||||||
|
self.confusion_3 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
|
||||||
|
self.confusion_5 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
|
||||||
|
self.confusion_bottle = nn.Conv2d(n_feats * 3 + d_feats * 2, n_feats, 1, padding=0, bias=True)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
input_1 = x
|
||||||
|
output_3_1 = self.relu(self.conv_3_1(input_1))
|
||||||
|
output_5_1 = self.relu(self.conv_5_1(input_1))
|
||||||
|
input_2 = torch.cat([input_1, output_3_1, output_5_1], 1)
|
||||||
|
input_2_3 = self.confusion_3(input_2)
|
||||||
|
input_2_5 = self.confusion_5(input_2)
|
||||||
|
|
||||||
|
output_3_2 = self.relu(self.conv_3_2(input_2_3))
|
||||||
|
output_5_2 = self.relu(self.conv_5_2(input_2_5))
|
||||||
|
input_3 = torch.cat([input_1, output_3_1, output_5_1, output_3_2, output_5_2], 1)
|
||||||
|
output = self.confusion_bottle(input_3)
|
||||||
|
output += x
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class CALayer(nn.Module):
|
||||||
|
def __init__(self, n_feats, reduction=16):
|
||||||
|
super(CALayer, self).__init__()
|
||||||
|
# global average pooling: feature --> point
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
# feature channel downscale and upscale --> channel weight
|
||||||
|
self.conv_du = nn.Sequential(
|
||||||
|
nn.Conv2d(n_feats, n_feats // reduction, 1, padding=0, bias=True),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(n_feats // reduction, n_feats, 1, padding=0, bias=True),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.avg_pool(x)
|
||||||
|
y = self.conv_du(y)
|
||||||
|
return x * y
|
||||||
|
|
||||||
|
|
||||||
|
class DB(nn.Module):
|
||||||
|
def __init__(self, conv=common.default_conv):
|
||||||
|
super(DB, self).__init__()
|
||||||
|
|
||||||
|
n_feats = 128
|
||||||
|
d_feats = 96
|
||||||
|
n_blocks = 12
|
||||||
|
|
||||||
|
self.fushion_down = nn.Conv2d(n_feats * (n_blocks - 1), d_feats, 1, padding=0, bias=True)
|
||||||
|
self.channel_attention = CALayer(d_feats)
|
||||||
|
self.fushion_up = nn.Conv2d(d_feats, n_feats, 1, padding=0, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fushion_down(x)
|
||||||
|
x = self.channel_attention(x)
|
||||||
|
x = self.fushion_up(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MDCN(nn.Module):
|
||||||
|
def __init__(self, args, conv=common.default_conv):
|
||||||
|
super(MDCN, self).__init__()
|
||||||
|
n_feats = 128
|
||||||
|
kernel_size = 3
|
||||||
|
self.scale_idx = 0
|
||||||
|
act = nn.ReLU(True)
|
||||||
|
|
||||||
|
n_blocks = 12
|
||||||
|
self.n_blocks = n_blocks
|
||||||
|
|
||||||
|
# RGB mean for DIV2K
|
||||||
|
rgb_mean = (0.4488, 0.4371, 0.4040)
|
||||||
|
rgb_std = (1.0, 1.0, 1.0)
|
||||||
|
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
|
||||||
|
|
||||||
|
# define head module
|
||||||
|
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
|
||||||
|
|
||||||
|
# define body module
|
||||||
|
modules_body = nn.ModuleList()
|
||||||
|
for i in range(n_blocks):
|
||||||
|
modules_body.append(MDCB())
|
||||||
|
|
||||||
|
# define distillation module
|
||||||
|
modules_dist = nn.ModuleList()
|
||||||
|
modules_dist.append(DB())
|
||||||
|
|
||||||
|
modules_transform = [conv(n_feats, n_feats, kernel_size)]
|
||||||
|
self.upsample = nn.ModuleList([
|
||||||
|
common.Upsampler(
|
||||||
|
conv, s, n_feats, act=True
|
||||||
|
) for s in args.scale
|
||||||
|
])
|
||||||
|
modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
|
||||||
|
|
||||||
|
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
|
||||||
|
|
||||||
|
self.head = nn.Sequential(*modules_head)
|
||||||
|
self.body = nn.Sequential(*modules_body)
|
||||||
|
self.dist = nn.Sequential(*modules_dist)
|
||||||
|
self.transform = nn.Sequential(*modules_transform)
|
||||||
|
self.rebult = nn.Sequential(*modules_rebult)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.sub_mean(x)
|
||||||
|
x = checkpoint(self.head, x)
|
||||||
|
front = x
|
||||||
|
|
||||||
|
MDCB_out = []
|
||||||
|
for i in range(self.n_blocks):
|
||||||
|
x = checkpoint(self.body[i], x)
|
||||||
|
if i != (self.n_blocks - 1):
|
||||||
|
MDCB_out.append(x)
|
||||||
|
|
||||||
|
hierarchical = torch.cat(MDCB_out, 1)
|
||||||
|
hierarchical = checkpoint(self.dist, hierarchical)
|
||||||
|
|
||||||
|
mix = front + hierarchical + x
|
||||||
|
|
||||||
|
out = checkpoint(self.transform, mix)
|
||||||
|
out = self.upsample[self.scale_idx](out)
|
||||||
|
out = checkpoint(self.rebult, out)
|
||||||
|
out = self.add_mean(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def set_scale(self, scale_idx):
|
||||||
|
self.scale_idx = scale_idx
|
|
@ -173,6 +173,10 @@ def define_G(opt, opt_net, scale=None):
|
||||||
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||||
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
|
||||||
initial_conv_stride=opt_net['initial_stride'])
|
initial_conv_stride=opt_net['initial_stride'])
|
||||||
|
elif which_model == 'mdcn':
|
||||||
|
from models.archs.mdcn.mdcn import MDCN
|
||||||
|
args = munchify({'scale': opt_net['scale'], 'n_colors': 3, 'rgb_range': 1.0})
|
||||||
|
netG = MDCN(args)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
return netG
|
return netG
|
||||||
|
|
Loading…
Reference in New Issue
Block a user