Delete mdcn & panet
Garbage, all of it.
This commit is contained in:
parent f2880b33c9
commit c18adbd606
models/archs/mdcn/common.py
@@ -1,86 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Note: this sets a plain attribute on the module, not on its
        # parameters; the 1x1 conv's weight and bias still have
        # requires_grad=True (a known quirk of this EDSR-style helper).
        self.requires_grad = False


class BasicBlock(nn.Sequential):
    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)
        ]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=False):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))

                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))

            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
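For reference, a minimal sketch of how these deleted building blocks compose (assumes `default_conv`, `ResBlock`, and `Upsampler` above are in scope; the shapes and `scale=2` are illustrative, not from this commit):

import torch

x = torch.randn(16, 128, 48, 48)                   # batch of 128-channel features

conv = default_conv
res = ResBlock(conv, n_feats=128, kernel_size=3)   # two 3x3 convs + identity skip
up = Upsampler(conv, scale=2, n_feats=128)         # conv to 4*C, then PixelShuffle(2)

y = up(res(x))
print(y.shape)   # torch.Size([16, 128, 96, 96])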
models/archs/mdcn/mdcn.py
@@ -1,143 +0,0 @@
from models.archs.mdcn import common

import torch
import torch.nn as nn

from utils.util import checkpoint


def make_model(args, parent=False):
    return MDCN(args)


class MDCB(nn.Module):
    """Multi-scale dense cross block: parallel 3x3 / 5x5 branches fused by 1x1 convs."""

    def __init__(self, conv=common.default_conv):
        super(MDCB, self).__init__()

        n_feats = 128
        d_feats = 96
        kernel_size_1 = 3
        kernel_size_2 = 5
        act = nn.ReLU(True)

        self.conv_3_1 = conv(n_feats, n_feats, kernel_size_1)
        self.conv_3_2 = conv(d_feats, d_feats, kernel_size_1)
        self.conv_5_1 = conv(n_feats, n_feats, kernel_size_2)
        self.conv_5_2 = conv(d_feats, d_feats, kernel_size_2)
        self.confusion_3 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
        self.confusion_5 = nn.Conv2d(n_feats * 3, d_feats, 1, padding=0, bias=True)
        self.confusion_bottle = nn.Conv2d(n_feats * 3 + d_feats * 2, n_feats, 1, padding=0, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        input_1 = x
        output_3_1 = self.relu(self.conv_3_1(input_1))
        output_5_1 = self.relu(self.conv_5_1(input_1))
        input_2 = torch.cat([input_1, output_3_1, output_5_1], 1)
        input_2_3 = self.confusion_3(input_2)
        input_2_5 = self.confusion_5(input_2)

        output_3_2 = self.relu(self.conv_3_2(input_2_3))
        output_5_2 = self.relu(self.conv_5_2(input_2_5))
        input_3 = torch.cat([input_1, output_3_1, output_5_1, output_3_2, output_5_2], 1)
        output = self.confusion_bottle(input_3)
        output += x
        return output


class CALayer(nn.Module):
    def __init__(self, n_feats, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(n_feats, n_feats // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats // reduction, n_feats, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


class DB(nn.Module):
    def __init__(self, conv=common.default_conv):
        super(DB, self).__init__()

        n_feats = 128
        d_feats = 96
        n_blocks = 12

        self.fushion_down = nn.Conv2d(n_feats * (n_blocks - 1), d_feats, 1, padding=0, bias=True)
        self.channel_attention = CALayer(d_feats)
        self.fushion_up = nn.Conv2d(d_feats, n_feats, 1, padding=0, bias=True)

    def forward(self, x):
        x = self.fushion_down(x)
        x = self.channel_attention(x)
        x = self.fushion_up(x)
        return x


class MDCN(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(MDCN, self).__init__()
        n_feats = 128
        kernel_size = 3
        self.scale_idx = 0
        act = nn.ReLU(True)

        n_blocks = 12
        self.n_blocks = n_blocks

        # define head module
        modules_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = nn.ModuleList()
        for i in range(n_blocks):
            modules_body.append(MDCB())

        # define distillation module
        modules_dist = nn.ModuleList()
        modules_dist.append(DB())

        modules_transform = [conv(n_feats, n_feats, kernel_size)]
        # Note: act=True matches neither 'relu' nor 'prelu' inside
        # common.Upsampler, so these upsamplers are built without an activation.
        self.upsample = nn.ModuleList([
            common.Upsampler(
                conv, s, n_feats, act=True
            ) for s in args.scale
        ])
        modules_rebult = [conv(n_feats, args.n_colors, kernel_size)]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.dist = nn.Sequential(*modules_dist)
        self.transform = nn.Sequential(*modules_transform)
        self.rebult = nn.Sequential(*modules_rebult)

    def forward(self, x):
        x = checkpoint(self.head, x)
        front = x

        MDCB_out = []
        for i in range(self.n_blocks):
            x = checkpoint(self.body[i], x)
            if i != (self.n_blocks - 1):
                MDCB_out.append(x)

        hierarchical = torch.cat(MDCB_out, 1)
        hierarchical = checkpoint(self.dist, hierarchical)

        mix = front + hierarchical + x

        out = checkpoint(self.transform, mix)
        out = self.upsample[self.scale_idx](out)
        out = checkpoint(self.rebult, out)
        return out

    def set_scale(self, scale_idx):
        self.scale_idx = scale_idx
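A minimal sketch of driving the deleted MDCN, assuming `utils.util.checkpoint` behaves like `torch.utils.checkpoint.checkpoint`; the `args` namespace here is hypothetical and carries only the two fields MDCN actually reads:

import types
import torch

args = types.SimpleNamespace(n_colors=3, scale=[2])   # hypothetical args

model = MDCN(args)
model.set_scale(0)   # select the x2 branch of self.upsample

lr = torch.randn(1, 3, 48, 48, requires_grad=True)    # checkpoint wants a grad-tracking input
sr = model(lr)
print(sr.shape)      # torch.Size([1, 3, 96, 96])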
models/archs/panet/attention.py
@@ -1,97 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import utils as vutils
import models.archs.panet.common as common
from models.archs.panet.tools import extract_image_patches, \
    reduce_mean, reduce_sum, same_padding
from utils.util import checkpoint


class PyramidAttention(nn.Module):
    def __init__(self, level=5, res_scale=1, channel=64, reduction=2, ksize=3,
                 stride=1, softmax_scale=10, average=True,
                 conv=common.default_conv):
        super(PyramidAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.res_scale = res_scale
        self.softmax_scale = softmax_scale
        self.scale = [1 - i / 10 for i in range(level)]  # pyramid scales 1.0, 0.9, ...
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_L_base = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match = common.BasicBlock(conv, channel, channel // reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
        res = input
        # theta
        match_base = self.conv_match_L_base(input)
        shape_base = list(res.size())
        input_groups = torch.split(match_base, 1, dim=0)
        # patch size for matching
        kernel = self.ksize
        # raw_w is for reconstruction
        raw_w = []
        # w is for matching
        w = []
        # build feature pyramid
        for i in range(len(self.scale)):
            ref = input
            if self.scale[i] != 1:
                ref = F.interpolate(input, scale_factor=self.scale[i], mode='bicubic')
            # feature transformation function f
            base = self.conv_assembly(ref)
            shape_input = base.shape
            # sampling
            raw_w_i = extract_image_patches(base, ksizes=[kernel, kernel],
                                            strides=[self.stride, self.stride],
                                            rates=[1, 1],
                                            padding='same')  # [N, C*k*k, L]
            raw_w_i = raw_w_i.view(shape_input[0], shape_input[1], kernel, kernel, -1)
            raw_w_i = raw_w_i.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
            raw_w_i_groups = torch.split(raw_w_i, 1, dim=0)
            raw_w.append(raw_w_i_groups)

            # feature transformation function g
            ref_i = self.conv_match(ref)
            shape_ref = ref_i.shape
            # sampling
            w_i = extract_image_patches(ref_i, ksizes=[self.ksize, self.ksize],
                                        strides=[self.stride, self.stride],
                                        rates=[1, 1],
                                        padding='same')
            w_i = w_i.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
            w_i = w_i.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
            w_i_groups = torch.split(w_i, 1, dim=0)
            w.append(w_i_groups)

        y = []
        for idx, xi in enumerate(input_groups):
            # group in a filter
            wi = torch.cat([w[i][idx][0] for i in range(len(self.scale))], dim=0)  # [L, C, k, k]
            # normalize
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi / max_wi
            # matching
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]; L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1, wi.shape[0], shape_base[2], shape_base[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax matching score
            yi = F.softmax(yi * self.softmax_scale, dim=1)

            if not self.average:
                # hard attention: keep only the best-matching patch per location
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()

            # deconv for patch pasting
            raw_wi = torch.cat([raw_w[i][idx][0] for i in range(len(self.scale))], dim=0)
            yi = F.conv_transpose2d(yi, raw_wi, stride=self.stride, padding=1) / 4.
            y.append(yi)

        y = torch.cat(y, dim=0) + res * self.res_scale  # back to the mini-batch
        return y
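A quick smoke test for the deleted PyramidAttention (the default `channel=64` must match the incoming feature depth; shapes are illustrative):

import torch

attn = PyramidAttention(channel=64, reduction=2)
feats = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    out = attn(feats)
print(out.shape)   # torch.Size([1, 64, 32, 32]) -- residual output, same shape as input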
models/archs/panet/common.py
@@ -1,87 +0,0 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), stride=stride, bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False


class BasicBlock(nn.Sequential):
    def __init__(
            self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
            bn=False, act=nn.PReLU()):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.PReLU(), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
models/archs/panet/panet.py
@@ -1,91 +0,0 @@
from models.archs.panet import common
from models.archs.panet import attention
import torch.nn as nn
from utils.util import checkpoint


def make_model(args, parent=False):
    return PANET(args)


class PANET(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(PANET, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        self.msa = attention.PyramidAttention()
        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        # (act is passed by keyword: ResBlock's fourth positional slot is bias)
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=nn.PReLU(), res_scale=args.res_scale
            ) for _ in range(n_resblocks // 2)
        ]
        m_body.append(self.msa)
        for i in range(n_resblocks // 2):
            m_body.append(common.ResBlock(conv, n_feats, kernel_size, act=nn.PReLU(), res_scale=args.res_scale))

        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)
        ]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.ModuleList(m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        # x = self.sub_mean(x)
        x = self.head(x)

        res = x
        for b in self.body:
            if b == self.msa:
                # The __main__ guard means the pyramid attention only runs when
                # this file is executed directly; when the model is imported,
                # the attention block is silently skipped.
                if __name__ == '__main__':
                    res = self.msa(res)
            else:
                res = checkpoint(b, res)

        res += x

        x = checkpoint(self.tail, res)
        # x = self.add_mean(x)

        return x,  # note: returned as a 1-tuple

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
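A sketch of constructing the deleted PANET. The `args` namespace is hypothetical but limited to the fields the constructor reads; `n_feats=64` is required so the default PyramidAttention channel count matches, and `utils.util.checkpoint` is again assumed to act like `torch.utils.checkpoint.checkpoint`:

import types
import torch

args = types.SimpleNamespace(n_resblocks=16, n_feats=64, scale=[2],
                             rgb_range=255, n_colors=3, res_scale=1)

model = PANET(args)
lr = torch.randn(1, 3, 32, 32, requires_grad=True)
sr, = model(lr)    # forward returns a 1-tuple
print(sr.shape)    # torch.Size([1, 3, 64, 64])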
models/archs/panet/tools.py
@@ -1,84 +0,0 @@
import os
import torch
import numpy as np
from PIL import Image

import torch.nn.functional as F


def normalize(x):
    return x.mul_(2).add_(-1)


def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images


def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding: 'same' or 'valid'
    :param images: a 4-D Tensor with shape [batch, channels, in_rows, in_cols]
    :param ksizes: [ksize_rows, ksize_cols], the size of the sliding window
        for each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: a Tensor of shape [N, C*k*k, L]
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks


def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x


def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
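These helpers mirror TensorFlow's extract_image_patches / reduce_* on NCHW tensors; a short check of the shapes they promise (values illustrative):

import torch

imgs = torch.randn(2, 8, 16, 16)
patches = extract_image_patches(imgs, ksizes=[3, 3], strides=[1, 1],
                                rates=[1, 1], padding='same')
print(patches.shape)   # torch.Size([2, 72, 256]) -- [N, C*k*k, H*W]

t = torch.ones(2, 3, 4)
print(reduce_sum(t, axis=[1, 2], keepdim=True).shape)   # torch.Size([2, 1, 1])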