diff --git a/codes/models/archs/mdcn/common.py b/codes/models/archs/mdcn/common.py
deleted file mode 100644
index f72b82d8..00000000
--- a/codes/models/archs/mdcn/common.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from torch.autograd import Variable
-
-def default_conv(in_channels, out_channels, kernel_size, bias=True):
-    return nn.Conv2d(
-        in_channels, out_channels, kernel_size,
-        padding=(kernel_size//2), bias=bias)
-
-class MeanShift(nn.Conv2d):
-    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
-        super(MeanShift, self).__init__(3, 3, kernel_size=1)
-        std = torch.Tensor(rgb_std)
-        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
-        self.weight.data.div_(std.view(3, 1, 1, 1))
-        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
-        self.bias.data.div_(std)
-        self.requires_grad = False
-
-class BasicBlock(nn.Sequential):
-    def __init__(
-        self, in_channels, out_channels, kernel_size, stride=1, bias=False,
-        bn=True, act=nn.ReLU(True)):
-
-        m = [nn.Conv2d(
-            in_channels, out_channels, kernel_size,
-            padding=(kernel_size//2), stride=stride, bias=bias)
-        ]
-        if bn: m.append(nn.BatchNorm2d(out_channels))
-        if act is not None: m.append(act)
-        super(BasicBlock, self).__init__(*m)
-
-class ResBlock(nn.Module):
-    def __init__(
-        self, conv, n_feats, kernel_size,
-        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
-
-        super(ResBlock, self).__init__()
-        m = []
-        for i in range(2):
-            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
-            if bn: m.append(nn.BatchNorm2d(n_feats))
-            if i == 0: m.append(act)
-
-        self.body = nn.Sequential(*m)
-        self.res_scale = res_scale
-
-    def forward(self, x):
-        res = self.body(x).mul(self.res_scale)
-        res += x
-
-        return res
-
-
-class Upsampler(nn.Sequential):
-    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=False):
-
-        m = []
-        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
-            for _ in range(int(math.log(scale, 2))):
-                m.append(conv(n_feats, 4 * n_feats, 3, bias))
-                m.append(nn.PixelShuffle(2))
-                if bn: m.append(nn.BatchNorm2d(n_feats))
-
-                if act == 'relu':
-                    m.append(nn.ReLU(True))
-                elif act == 'prelu':
-                    m.append(nn.PReLU(n_feats))
-
-        elif scale == 3:
-            m.append(conv(n_feats, 9 * n_feats, 3, bias))
-            m.append(nn.PixelShuffle(3))
-            if bn: m.append(nn.BatchNorm2d(n_feats))
-
-            if act == 'relu':
-                m.append(nn.ReLU(True))
-            elif act == 'prelu':
-                m.append(nn.PReLU(n_feats))
-        else:
-            raise NotImplementedError
-
-        super(Upsampler, self).__init__(*m)
diff --git a/codes/models/archs/mdcn/mdcn.py b/codes/models/archs/mdcn/mdcn.py
deleted file mode 100644
index f3492275..00000000
--- a/codes/models/archs/mdcn/mdcn.py
+++ /dev/null
@@ -1,143 +0,0 @@
-from models.archs.mdcn import common
-import torch
-import torch.nn as nn
-
-from utils.util import checkpoint
-
-
-def make_model(args, parent=False):
-    return MDCN(args)
-
-
-class MDCB(nn.Module):
-    def __init__(self, conv=common.default_conv):
-        super(MDCB, self).__init__()
-
-        n_feats = 128
-        d_feats = 96
-        kernel_size_1 = 3
-        kernel_size_2 = 5
-        act = nn.ReLU(True)
-
-        self.conv_3_1 = conv(n_feats, n_feats, kernel_size_1)
-        self.conv_3_2 = conv(d_feats, d_feats, kernel_size_1)
-        self.conv_5_1 = conv(n_feats, n_feats, kernel_size_2)
-        self.conv_5_2 = conv(d_feats, d_feats, kernel_size_2)
-        self.confusion_3 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
-        self.confusion_5 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
-        self.confusion_bottle = nn.Conv2d(n_feats * 3 + d_feats * 2, n_feats, 1, padding=0, bias=True)
-        self.relu = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        input_1 = x
-        output_3_1 = self.relu(self.conv_3_1(input_1))
-        output_5_1 = self.relu(self.conv_5_1(input_1))
-        input_2 = torch.cat([input_1, output_3_1, output_5_1], 1)
-        input_2_3 = self.confusion_3(input_2)
-        input_2_5 = self.confusion_5(input_2)
-
-        output_3_2 = self.relu(self.conv_3_2(input_2_3))
-        output_5_2 = self.relu(self.conv_5_2(input_2_5))
-        input_3 = torch.cat([input_1, output_3_1, output_5_1, output_3_2, output_5_2], 1)
-        output = self.confusion_bottle(input_3)
-        output += x
-        return output
-
-
-class CALayer(nn.Module):
-    def __init__(self, n_feats, reduction=16):
-        super(CALayer, self).__init__()
-        # global average pooling: feature --> point
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        # feature channel downscale and upscale --> channel weight
-        self.conv_du = nn.Sequential(
-            nn.Conv2d(n_feats, n_feats // reduction, 1, padding=0, bias=True),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(n_feats // reduction, n_feats, 1, padding=0, bias=True),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        y = self.avg_pool(x)
-        y = self.conv_du(y)
-        return x * y
-
-
-class DB(nn.Module):
-    def __init__(self, conv=common.default_conv):
-        super(DB, self).__init__()
-
-        n_feats = 128
-        d_feats = 96
-        n_blocks = 12
-
-        self.fushion_down = nn.Conv2d(n_feats * (n_blocks - 1), d_feats, 1, padding=0, bias=True)
-        self.channel_attention = CALayer(d_feats)
-        self.fushion_up = nn.Conv2d(d_feats, n_feats, 1, padding=0, bias=True)
-
-    def forward(self, x):
-        x = self.fushion_down(x)
-        x = self.channel_attention(x)
-        x = self.fushion_up(x)
-        return x
-
-
-class MDCN(nn.Module):
-    def __init__(self, args, conv=common.default_conv):
-        super(MDCN, self).__init__()
-        n_feats = 128
-        kernel_size = 3
-        self.scale_idx = 0
-        act = nn.ReLU(True)
-
-        n_blocks = 12
-        self.n_blocks = n_blocks
-
-        # define head module
-        modules_head = [conv(args.n_colors, n_feats, kernel_size)]
-
-        # define body module
-        modules_body = nn.ModuleList()
-        for i in range(n_blocks):
-            modules_body.append(MDCB())
-
-        # define distillation module
-        modules_dist = nn.ModuleList()
-        modules_dist.append(DB())
-
-        modules_transform = [conv(n_feats, n_feats, kernel_size)]
-        self.upsample = nn.ModuleList([
-            common.Upsampler(
-                conv, s, n_feats, act=True
-            ) for s in args.scale
-        ])
-        modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]
-
-        self.head = nn.Sequential(*modules_head)
-        self.body = nn.Sequential(*modules_body)
-        self.dist = nn.Sequential(*modules_dist)
-        self.transform = nn.Sequential(*modules_transform)
-        self.rebult = nn.Sequential(*modules_rebult)
-
-    def forward(self, x):
-        x = checkpoint(self.head, x)
-        front = x
-
-        MDCB_out = []
-        for i in range(self.n_blocks):
-            x = checkpoint(self.body[i], x)
-            if i != (self.n_blocks - 1):
-                MDCB_out.append(x)
-
-        hierarchical = torch.cat(MDCB_out, 1)
-        hierarchical = checkpoint(self.dist, hierarchical)
-
-        mix = front + hierarchical + x
-
-        out = checkpoint(self.transform, mix)
-        out = self.upsample[self.scale_idx](out)
-        out = checkpoint(self.rebult, out)
-        return out
-
-    def set_scale(self, scale_idx):
-        self.scale_idx = scale_idx
\ No newline at end of file
diff --git a/codes/models/archs/panet/attention.py b/codes/models/archs/panet/attention.py
deleted file mode 100644
index 4ce2c20a..00000000
--- a/codes/models/archs/panet/attention.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchvision import transforms
-from torchvision import utils as vutils
-import models.archs.panet.common as common
-from models.archs.panet.tools import extract_image_patches, \
-    reduce_mean, reduce_sum, same_padding
-from utils.util import checkpoint
-
-
-class PyramidAttention(nn.Module):
-    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3, stride=1, softmax_scale=10, average=True,
-                 conv=common.default_conv):
-        super(PyramidAttention, self).__init__()
-        self.ksize = ksize
-        self.stride = stride
-        self.res_scale = res_scale
-        self.softmax_scale = softmax_scale
-        self.scale = [1 - i / 10 for i in range(level)]
-        self.average = average
-        escape_NaN = torch.FloatTensor([1e-4])
-        self.register_buffer('escape_NaN', escape_NaN)
-        self.conv_match_L_base = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
-        self.conv_match = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
-        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
-
-    def forward(self, input):
-        res = input
-        # theta
-        match_base = self.conv_match_L_base(input)
-        shape_base = list(res.size())
-        input_groups = torch.split(match_base, 1, dim=0)
-        # patch size for matching
-        kernel = self.ksize
-        # raw_w is for reconstruction
-        raw_w = []
-        # w is for matching
-        w = []
-        # build feature pyramid
-        for i in range(len(self.scale)):
-            ref = input
-            if self.scale[i] != 1:
-                ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
-            # feature transformation function f
-            base = self.conv_assembly(ref)
-            shape_input = base.shape
-            # sampling
-            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
-                                            strides=[self.stride, self.stride],
-                                            rates=[1, 1],
-                                            padding='same')  # [N, C*k*k, L]
-            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
-            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
-            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
-            raw_w.append(raw_w_i_groups)
-
-            # feature transformation function g
-            ref_i = self.conv_match(ref)
-            shape_ref = ref_i.shape
-            # sampling
-            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
-                                        strides=[self.stride, self.stride],
-                                        rates=[1, 1],
-                                        padding='same')
-            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
-            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
-            w_i_groups = torch.split(w_i, 1, dim=0)
-            w.append(w_i_groups)
-
-        y = []
-        for idx, xi in enumerate(input_groups):
-            # group in a filter
-            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
-            # normalize
-            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
-                                                     axis=[1, 2, 3],
-                                                     keepdim=True)),
-                               self.escape_NaN)
-            wi_normed = wi / max_wi
-            # matching
-            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
-            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
-            yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
-            # softmax matching score
-            yi = F.softmax(yi * self.softmax_scale, dim=1)
-
-            if self.average == False:
-                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()
-
-            # deconv for patch pasting
-            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
-            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
-            y.append(yi)
-
-        y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
-        return y
\ No newline at end of file
diff --git a/codes/models/archs/panet/common.py b/codes/models/archs/panet/common.py
deleted file mode 100644
index b67e6a74..00000000
--- a/codes/models/archs/panet/common.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
-    return nn.Conv2d(
-        in_channels, out_channels, kernel_size,
-        padding=(kernel_size//2),stride=stride, bias=bias)
-
-class MeanShift(nn.Conv2d):
-    def __init__(
-        self, rgb_range,
-        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
-
-        super(MeanShift, self).__init__(3, 3, kernel_size=1)
-        std = torch.Tensor(rgb_std)
-        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
-        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
-        for p in self.parameters():
-            p.requires_grad = False
-
-class BasicBlock(nn.Sequential):
-    def __init__(
-        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
-        bn=False, act=nn.PReLU()):
-
-        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
-        if bn:
-            m.append(nn.BatchNorm2d(out_channels))
-        if act is not None:
-            m.append(act)
-
-        super(BasicBlock, self).__init__(*m)
-
-class ResBlock(nn.Module):
-    def __init__(
-        self, conv, n_feats, kernel_size,
-        bias=True, bn=False, act=nn.PReLU(), res_scale=1):
-
-        super(ResBlock, self).__init__()
-        m = []
-        for i in range(2):
-            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
-            if bn:
-                m.append(nn.BatchNorm2d(n_feats))
-            if i == 0:
-                m.append(act)
-
-        self.body = nn.Sequential(*m)
-        self.res_scale = res_scale
-
-    def forward(self, x):
-        res = self.body(x).mul(self.res_scale)
-        res += x
-
-        return res
-
-class Upsampler(nn.Sequential):
-    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
-
-        m = []
-        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
-            for _ in range(int(math.log(scale, 2))):
-                m.append(conv(n_feats, 4 * n_feats, 3, bias))
-                m.append(nn.PixelShuffle(2))
-                if bn:
-                    m.append(nn.BatchNorm2d(n_feats))
-                if act == 'relu':
-                    m.append(nn.ReLU(True))
-                elif act == 'prelu':
-                    m.append(nn.PReLU(n_feats))
-
-        elif scale == 3:
-            m.append(conv(n_feats, 9 * n_feats, 3, bias))
-            m.append(nn.PixelShuffle(3))
-            if bn:
-                m.append(nn.BatchNorm2d(n_feats))
-            if act == 'relu':
-                m.append(nn.ReLU(True))
-            elif act == 'prelu':
-                m.append(nn.PReLU(n_feats))
-        else:
-            raise NotImplementedError
-
-        super(Upsampler, self).__init__(*m)
diff --git a/codes/models/archs/panet/panet.py b/codes/models/archs/panet/panet.py
deleted file mode 100644
index 9e226f7e..00000000
--- a/codes/models/archs/panet/panet.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from models.archs.panet import common
-from models.archs.panet import attention
-import torch.nn as nn
-from utils.util import checkpoint
-
-
-def make_model(args, parent=False):
-    return PANET(args)
-
-
-class PANET(nn.Module):
-    def __init__(self, args, conv=common.default_conv):
-        super(PANET, self).__init__()
-
-        n_resblocks = args.n_resblocks
-        n_feats = args.n_feats
-        kernel_size = 3
-        scale = args.scale[0]
-
-        rgb_mean = (0.4488, 0.4371, 0.4040)
-        rgb_std = (1.0, 1.0, 1.0)
-        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
-        self.msa = attention.PyramidAttention()
-        # define head module
-        m_head = [conv(args.n_colors, n_feats, kernel_size)]
-
-        # define body module
-        m_body = [
-            common.ResBlock(
-                conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale
-            ) for _ in range(n_resblocks // 2)
-        ]
-        m_body.append(self.msa)
-        for i in range(n_resblocks // 2):
-            m_body.append(common.ResBlock(conv, n_feats, kernel_size, nn.PReLU(), res_scale=args.res_scale))
-
-        m_body.append(conv(n_feats, n_feats, kernel_size))
-
-        # define tail module
-        # m_tail = [
-        #     common.Upsampler(conv, scale, n_feats, act=False),
-        #     conv(n_feats, args.n_colors, kernel_size)
-        # ]
-        m_tail = [
-            common.Upsampler(conv, scale, n_feats, act=False),
-            conv(n_feats, args.n_colors, kernel_size)
-        ]
-
-        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
-
-        self.head = nn.Sequential(*m_head)
-        self.body = nn.ModuleList(m_body)
-        self.tail = nn.Sequential(*m_tail)
-
-    def forward(self, x):
-        # x = self.sub_mean(x)
-        x = self.head(x)
-
-        res = x
-        for b in self.body:
-            if b == self.msa:
-                if __name__ == '__main__':
-                    res = self.msa(res)
-            else:
-                res = checkpoint(b, res)
-
-        res += x
-
-        x = checkpoint(self.tail, res)
-        # x = self.add_mean(x)
-
-        return x,
-
-    def load_state_dict(self, state_dict, strict=True):
-        own_state = self.state_dict()
-        for name, param in state_dict.items():
-            if name in own_state:
-                if isinstance(param, nn.Parameter):
-                    param = param.data
-                try:
-                    own_state[name].copy_(param)
-                except Exception:
-                    if name.find('tail') == -1:
-                        raise RuntimeError('While copying the parameter named {}, '
-                                           'whose dimensions in the model are {} and '
-                                           'whose dimensions in the checkpoint are {}.'
-                                           .format(name, own_state[name].size(), param.size()))
-            elif strict:
-                if name.find('tail') == -1:
-                    raise KeyError('unexpected key "{}" in state_dict'
-                                   .format(name))
diff --git a/codes/models/archs/panet/tools.py b/codes/models/archs/panet/tools.py
deleted file mode 100644
index 47e7f971..00000000
--- a/codes/models/archs/panet/tools.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-import torch
-import numpy as np
-from PIL import Image
-
-import torch.nn.functional as F
-
-
-def normalize(x):
-    return x.mul_(2).add_(-1)
-
-
-def same_padding(images, ksizes, strides, rates):
-    assert len(images.size()) == 4
-    batch_size, channel, rows, cols = images.size()
-    out_rows = (rows + strides[0] - 1) // strides[0]
-    out_cols = (cols + strides[1] - 1) // strides[1]
-    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
-    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
-    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
-    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
-    # Pad the input
-    padding_top = int(padding_rows / 2.)
-    padding_left = int(padding_cols / 2.)
-    padding_bottom = padding_rows - padding_top
-    padding_right = padding_cols - padding_left
-    paddings = (padding_left, padding_right, padding_top, padding_bottom)
-    images = torch.nn.ZeroPad2d(paddings)(images)
-    return images
-
-
-def extract_image_patches(images, ksizes, strides, rates, padding='same'):
-    """
-    Extract patches from images and put them in the C output dimension.
-    :param padding:
-    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
-    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
-      each dimension of images
-    :param strides: [stride_rows, stride_cols]
-    :param rates: [dilation_rows, dilation_cols]
-    :return: A Tensor
-    """
-    assert len(images.size()) == 4
-    assert padding in ['same', 'valid']
-    batch_size, channel, height, width = images.size()
-
-    if padding == 'same':
-        images = same_padding(images, ksizes, strides, rates)
-    elif padding == 'valid':
-        pass
-    else:
-        raise NotImplementedError('Unsupported padding type: {}.\
-                Only "same" or "valid" are supported.'.format(padding))
-
-    unfold = torch.nn.Unfold(kernel_size=ksizes,
-                             dilation=rates,
-                             padding=0,
-                             stride=strides)
-    patches = unfold(images)
-    return patches  # [N, C*k*k, L], L is the total number of such blocks
-
-
-def reduce_mean(x, axis=None, keepdim=False):
-    if not axis:
-        axis = range(len(x.shape))
-    for i in sorted(axis, reverse=True):
-        x = torch.mean(x, dim=i, keepdim=keepdim)
-    return x
-
-
-def reduce_std(x, axis=None, keepdim=False):
-    if not axis:
-        axis = range(len(x.shape))
-    for i in sorted(axis, reverse=True):
-        x = torch.std(x, dim=i, keepdim=keepdim)
-    return x
-
-
-def reduce_sum(x, axis=None, keepdim=False):
-    if not axis:
-        axis = range(len(x.shape))
-    for i in sorted(axis, reverse=True):
-        x = torch.sum(x, dim=i, keepdim=keepdim)
-    return x