Re-add original SPSR_arch
This commit is contained in:
parent
31cf1ac98d
commit
8202ee72b9
|
@ -69,6 +69,7 @@ class ImageGradientNoPadding(nn.Module):
|
|||
|
||||
|
||||
def forward(self, x):
|
||||
x = x.float()
|
||||
x_list = []
|
||||
for i in range(x.shape[1]):
|
||||
x_i = x[:, i]
|
||||
|
@ -86,6 +87,139 @@ class ImageGradientNoPadding(nn.Module):
|
|||
# Generator
|
||||
####################
|
||||
|
||||
class SPSRNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
|
||||
super(SPSRNet, self).__init__()
|
||||
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
rb_blocks = [RRDB(nf) for _ in range(nb)]
|
||||
|
||||
LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
upsample_block = UpconvBlock
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, activation=True)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)]
|
||||
|
||||
self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||
self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
|
||||
*upsampler, self.HR_conv0_new)
|
||||
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_1 = RRDB(nf * 2)
|
||||
|
||||
self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_2 = RRDB(nf * 2)
|
||||
|
||||
self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_3 = RRDB(nf * 2)
|
||||
|
||||
self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_4 = RRDB(nf * 2)
|
||||
|
||||
self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
if upscale == 3:
|
||||
b_upsampler = UpconvBlock(nf, nf, activation=True)
|
||||
else:
|
||||
b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)]
|
||||
|
||||
b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||
b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
|
||||
|
||||
self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False)
|
||||
|
||||
self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.f_block = RRDB(nf * 2)
|
||||
|
||||
self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||
self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
x = self.model[0](x)
|
||||
|
||||
x, block_list = self.model[1](x)
|
||||
|
||||
x_ori = x
|
||||
for i in range(5):
|
||||
x = block_list[i](x)
|
||||
x_fea1 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i + 5](x)
|
||||
x_fea2 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i + 10](x)
|
||||
x_fea3 = x
|
||||
|
||||
for i in range(5):
|
||||
x = block_list[i + 15](x)
|
||||
x_fea4 = x
|
||||
|
||||
x = block_list[20:](x)
|
||||
# short cut
|
||||
x = x_ori + x
|
||||
x = self.model[2:](x)
|
||||
x = self.HR_conv1_new(x)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||
|
||||
x_cat_1 = self.b_block_1(x_cat_1)
|
||||
x_cat_1 = self.b_concat_1(x_cat_1)
|
||||
|
||||
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||
|
||||
x_cat_2 = self.b_block_2(x_cat_2)
|
||||
x_cat_2 = self.b_concat_2(x_cat_2)
|
||||
|
||||
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||
|
||||
x_cat_3 = self.b_block_3(x_cat_3)
|
||||
x_cat_3 = self.b_concat_3(x_cat_3)
|
||||
|
||||
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||
|
||||
x_cat_4 = self.b_block_4(x_cat_4)
|
||||
x_cat_4 = self.b_concat_4(x_cat_4)
|
||||
|
||||
x_cat_4 = self.b_LR_conv(x_cat_4)
|
||||
|
||||
# short cut
|
||||
x_cat_4 = x_cat_4 + x_b_fea
|
||||
x_branch = self.b_module(x_cat_4)
|
||||
|
||||
x_out_branch = self.conv_w(x_branch)
|
||||
########
|
||||
x_branch_d = x_branch
|
||||
x_f_cat = torch.cat([x_branch_d, x], dim=1)
|
||||
x_f_cat = self.f_block(x_f_cat)
|
||||
x_out = self.f_concat(x_f_cat)
|
||||
x_out = self.f_HR_conv0(x_out)
|
||||
x_out = self.f_HR_conv1(x_out)
|
||||
|
||||
#########
|
||||
return x_out_branch, x_out, x_grad
|
||||
|
||||
|
||||
class SPSRNetSimplified(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, upscale=4):
|
||||
|
|
|
@ -59,6 +59,9 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||
elif which_model == 'spsr':
|
||||
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
elif which_model == 'spsr_net_improved':
|
||||
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user