From 8202ee72b9d1cbdd198a596d3dba6fc12a3112ae Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 27 Oct 2020 11:00:38 -0600 Subject: [PATCH] Re-add original SPSR_arch --- codes/models/archs/SPSR_arch.py | 134 ++++++++++++++++++++++++++++++++ codes/models/networks.py | 3 + 2 files changed, 137 insertions(+) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 1937b0b5..6f60c324 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -69,6 +69,7 @@ class ImageGradientNoPadding(nn.Module): def forward(self, x): + x = x.float() x_list = [] for i in range(x.shape[1]): x_i = x[:, i] @@ -86,6 +87,139 @@ class ImageGradientNoPadding(nn.Module): # Generator #################### +class SPSRNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \ + act_type='leakyrelu', mode='CNA', upsample_mode='upconv'): + super(SPSRNet, self).__init__() + + n_upscale = int(math.log(upscale, 2)) + + if upscale == 3: + n_upscale = 1 + + fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + rb_blocks = [RRDB(nf) for _ in range(nb)] + + LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + + upsample_block = UpconvBlock + if upscale == 3: + upsampler = upsample_block(nf, nf, activation=True) + else: + upsampler = [upsample_block(nf, nf, activation=True) for _ in range(n_upscale)] + + self.HR_conv0_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True) + self.HR_conv1_new = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + + self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \ + *upsampler, self.HR_conv0_new) + + self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + + self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) + self.b_block_1 = RRDB(nf * 2) + + self.b_concat_2 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) + self.b_block_2 = RRDB(nf * 2) + + self.b_concat_3 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) + self.b_block_3 = RRDB(nf * 2) + + self.b_concat_4 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) + self.b_block_4 = RRDB(nf * 2) + + self.b_LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + + if upscale == 3: + b_upsampler = UpconvBlock(nf, nf, activation=True) + else: + b_upsampler = [UpconvBlock(nf, nf, activation=True) for _ in range(n_upscale)] + + b_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True) + b_HR_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) + + self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1) + + self.conv_w = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False) + + self.f_concat = ConvGnLelu(nf * 2, nf, kernel_size=3, norm=False, activation=False) + + self.f_block = RRDB(nf * 2) + + self.f_HR_conv0 = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True) + self.f_HR_conv1 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False) + + self.get_g_nopadding = ImageGradientNoPadding() + + def forward(self, x): + + x_grad = self.get_g_nopadding(x) + x = self.model[0](x) + + x, block_list = self.model[1](x) + + x_ori = x + for i in range(5): + x = block_list[i](x) + x_fea1 = x + + for i in range(5): + x = block_list[i + 5](x) + x_fea2 = x + + for i in range(5): + x = block_list[i + 10](x) + x_fea3 = x + + for i in range(5): + x = block_list[i + 15](x) + x_fea4 = x + + x = block_list[20:](x) + # short cut + x = x_ori + x + x = self.model[2:](x) + x = self.HR_conv1_new(x) + + x_b_fea = self.b_fea_conv(x_grad) + x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) + + x_cat_1 = self.b_block_1(x_cat_1) + x_cat_1 = self.b_concat_1(x_cat_1) + + x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1) + + x_cat_2 = self.b_block_2(x_cat_2) + x_cat_2 = self.b_concat_2(x_cat_2) + + x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1) + + x_cat_3 = self.b_block_3(x_cat_3) + x_cat_3 = self.b_concat_3(x_cat_3) + + x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1) + + x_cat_4 = self.b_block_4(x_cat_4) + x_cat_4 = self.b_concat_4(x_cat_4) + + x_cat_4 = self.b_LR_conv(x_cat_4) + + # short cut + x_cat_4 = x_cat_4 + x_b_fea + x_branch = self.b_module(x_cat_4) + + x_out_branch = self.conv_w(x_branch) + ######## + x_branch_d = x_branch + x_f_cat = torch.cat([x_branch_d, x], dim=1) + x_f_cat = self.f_block(x_f_cat) + x_out = self.f_concat(x_f_cat) + x_out = self.f_HR_conv0(x_out) + x_out = self.f_HR_conv1(x_out) + + ######### + return x_out_branch, x_out, x_grad + class SPSRNetSimplified(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, upscale=4): diff --git a/codes/models/networks.py b/codes/models/networks.py index fbb67440..10da1cd9 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -59,6 +59,9 @@ def define_G(opt, net_key='network_G', scale=None): initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'], heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'], upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise']) + elif which_model == 'spsr': + netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], + nb=opt_net['nb'], upscale=opt_net['scale']) elif which_model == 'spsr_net_improved': netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])