Re-add original SPSR_arch

This commit is contained in:
James Betker 2020-10-27 11:00:38 -06:00
parent 31cf1ac98d
commit 8202ee72b9
2 changed files with 137 additions and 0 deletions

View File

@ -69,6 +69,7 @@ class ImageGradientNoPadding(nn.Module):
def forward(self, x):
x = x.float()
x_list = []
for i in range(x.shape[1]):
x_i = x[:, i]
@ -86,6 +87,139 @@ class ImageGradientNoPadding(nn.Module):
# Generator
####################
class SPSRNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(SPSRNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
rb_blocks = [RRDB(nf) for _ in range(nb)]
LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
upsample_block = UpconvBlock
if upscale == 3:
upsampler = upsample_block(nf, nf, activation=True)
else:
upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)]
self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
*upsampler, self.HR_conv0_new)
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_1 = RRDB(nf * 2)
self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_2 = RRDB(nf * 2)
self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_3 = RRDB(nf * 2)
self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
self.b_block_4 = RRDB(nf * 2)
self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
if upscale == 3:
b_upsampler = UpconvBlock(nf, nf, activation=True)
else:
b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)]
b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False)
self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False)
self.f_block = RRDB(nf * 2)
self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False)
self.get_g_nopadding = ImageGradientNoPadding()
def forward(self, x):
x_grad = self.get_g_nopadding(x)
x = self.model[0](x)
x, block_list = self.model[1](x)
x_ori = x
for i in range(5):
x = block_list[i](x)
x_fea1 = x
for i in range(5):
x = block_list[i + 5](x)
x_fea2 = x
for i in range(5):
x = block_list[i + 10](x)
x_fea3 = x
for i in range(5):
x = block_list[i + 15](x)
x_fea4 = x
x = block_list[20:](x)
# short cut
x = x_ori + x
x = self.model[2:](x)
x = self.HR_conv1_new(x)
x_b_fea = self.b_fea_conv(x_grad)
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
x_cat_1 = self.b_block_1(x_cat_1)
x_cat_1 = self.b_concat_1(x_cat_1)
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
x_cat_2 = self.b_block_2(x_cat_2)
x_cat_2 = self.b_concat_2(x_cat_2)
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
x_cat_3 = self.b_block_3(x_cat_3)
x_cat_3 = self.b_concat_3(x_cat_3)
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
x_cat_4 = self.b_block_4(x_cat_4)
x_cat_4 = self.b_concat_4(x_cat_4)
x_cat_4 = self.b_LR_conv(x_cat_4)
# short cut
x_cat_4 = x_cat_4 + x_b_fea
x_branch = self.b_module(x_cat_4)
x_out_branch = self.conv_w(x_branch)
########
x_branch_d = x_branch
x_f_cat = torch.cat([x_branch_d, x], dim=1)
x_f_cat = self.f_block(x_f_cat)
x_out = self.f_concat(x_f_cat)
x_out = self.f_HR_conv0(x_out)
x_out = self.f_HR_conv1(x_out)
#########
return x_out_branch, x_out, x_grad
class SPSRNetSimplified(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, upscale=4):

View File

@ -59,6 +59,9 @@ def define_G(opt, net_key='network_G', scale=None):
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
elif which_model == 'spsr':
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'spsr_net_improved':
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
nb=opt_net['nb'], upscale=opt_net['scale'])