forked from mrq/DL-Art-School
Re-add original SPSR_arch
This commit is contained in:
parent
31cf1ac98d
commit
8202ee72b9
|
@ -69,6 +69,7 @@ class ImageGradientNoPadding(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
x = x.float()
|
||||||
x_list = []
|
x_list = []
|
||||||
for i in range(x.shape[1]):
|
for i in range(x.shape[1]):
|
||||||
x_i = x[:, i]
|
x_i = x[:, i]
|
||||||
|
@ -86,6 +87,139 @@ class ImageGradientNoPadding(nn.Module):
|
||||||
# Generator
|
# Generator
|
||||||
####################
|
####################
|
||||||
|
|
||||||
|
class SPSRNet(nn.Module):
|
||||||
|
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
|
||||||
|
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
|
||||||
|
super(SPSRNet, self).__init__()
|
||||||
|
|
||||||
|
n_upscale = int(math.log(upscale, 2))
|
||||||
|
|
||||||
|
if upscale == 3:
|
||||||
|
n_upscale = 1
|
||||||
|
|
||||||
|
fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
rb_blocks = [RRDB(nf) for _ in range(nb)]
|
||||||
|
|
||||||
|
LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
upsample_block = UpconvBlock
|
||||||
|
if upscale == 3:
|
||||||
|
upsampler = upsample_block(nf, nf, activation=True)
|
||||||
|
else:
|
||||||
|
upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)]
|
||||||
|
|
||||||
|
self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||||
|
self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
|
||||||
|
*upsampler, self.HR_conv0_new)
|
||||||
|
|
||||||
|
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
self.b_block_1 = RRDB(nf * 2)
|
||||||
|
|
||||||
|
self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
self.b_block_2 = RRDB(nf * 2)
|
||||||
|
|
||||||
|
self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
self.b_block_3 = RRDB(nf * 2)
|
||||||
|
|
||||||
|
self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
self.b_block_4 = RRDB(nf * 2)
|
||||||
|
|
||||||
|
self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
if upscale == 3:
|
||||||
|
b_upsampler = UpconvBlock(nf, nf, activation=True)
|
||||||
|
else:
|
||||||
|
b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)]
|
||||||
|
|
||||||
|
b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||||
|
b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
|
||||||
|
|
||||||
|
self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.f_block = RRDB(nf * 2)
|
||||||
|
|
||||||
|
self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True)
|
||||||
|
self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False)
|
||||||
|
|
||||||
|
self.get_g_nopadding = ImageGradientNoPadding()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
x_grad = self.get_g_nopadding(x)
|
||||||
|
x = self.model[0](x)
|
||||||
|
|
||||||
|
x, block_list = self.model[1](x)
|
||||||
|
|
||||||
|
x_ori = x
|
||||||
|
for i in range(5):
|
||||||
|
x = block_list[i](x)
|
||||||
|
x_fea1 = x
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
x = block_list[i + 5](x)
|
||||||
|
x_fea2 = x
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
x = block_list[i + 10](x)
|
||||||
|
x_fea3 = x
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
x = block_list[i + 15](x)
|
||||||
|
x_fea4 = x
|
||||||
|
|
||||||
|
x = block_list[20:](x)
|
||||||
|
# short cut
|
||||||
|
x = x_ori + x
|
||||||
|
x = self.model[2:](x)
|
||||||
|
x = self.HR_conv1_new(x)
|
||||||
|
|
||||||
|
x_b_fea = self.b_fea_conv(x_grad)
|
||||||
|
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||||
|
|
||||||
|
x_cat_1 = self.b_block_1(x_cat_1)
|
||||||
|
x_cat_1 = self.b_concat_1(x_cat_1)
|
||||||
|
|
||||||
|
x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
|
||||||
|
|
||||||
|
x_cat_2 = self.b_block_2(x_cat_2)
|
||||||
|
x_cat_2 = self.b_concat_2(x_cat_2)
|
||||||
|
|
||||||
|
x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
|
||||||
|
|
||||||
|
x_cat_3 = self.b_block_3(x_cat_3)
|
||||||
|
x_cat_3 = self.b_concat_3(x_cat_3)
|
||||||
|
|
||||||
|
x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
|
||||||
|
|
||||||
|
x_cat_4 = self.b_block_4(x_cat_4)
|
||||||
|
x_cat_4 = self.b_concat_4(x_cat_4)
|
||||||
|
|
||||||
|
x_cat_4 = self.b_LR_conv(x_cat_4)
|
||||||
|
|
||||||
|
# short cut
|
||||||
|
x_cat_4 = x_cat_4 + x_b_fea
|
||||||
|
x_branch = self.b_module(x_cat_4)
|
||||||
|
|
||||||
|
x_out_branch = self.conv_w(x_branch)
|
||||||
|
########
|
||||||
|
x_branch_d = x_branch
|
||||||
|
x_f_cat = torch.cat([x_branch_d, x], dim=1)
|
||||||
|
x_f_cat = self.f_block(x_f_cat)
|
||||||
|
x_out = self.f_concat(x_f_cat)
|
||||||
|
x_out = self.f_HR_conv0(x_out)
|
||||||
|
x_out = self.f_HR_conv1(x_out)
|
||||||
|
|
||||||
|
#########
|
||||||
|
return x_out_branch, x_out, x_grad
|
||||||
|
|
||||||
|
|
||||||
class SPSRNetSimplified(nn.Module):
|
class SPSRNetSimplified(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, upscale=4):
|
def __init__(self, in_nc, out_nc, nf, nb, upscale=4):
|
||||||
|
|
|
@ -59,6 +59,9 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
|
||||||
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
|
||||||
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
|
||||||
|
elif which_model == 'spsr':
|
||||||
|
netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||||
|
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||||
elif which_model == 'spsr_net_improved':
|
elif which_model == 'spsr_net_improved':
|
||||||
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
|
||||||
nb=opt_net['nb'], upscale=opt_net['scale'])
|
nb=opt_net['nb'], upscale=opt_net['scale'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user