diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py index 8783e252..ea597f64 100644 --- a/codes/models/archs/ResGen_arch.py +++ b/codes/models/archs/ResGen_arch.py @@ -59,10 +59,11 @@ class FixupBasicBlock(nn.Module): class FixupResNet(nn.Module): - def __init__(self, block, layers, num_filters=64): + def __init__(self, block, layers, upscale_applications=2, num_filters=64): super(FixupResNet, self).__init__() self.num_layers = sum(layers) + layers[-1] # The last layer is applied twice to achieve 4x upsampling. self.inplanes = num_filters + self.upscale_applications = upscale_applications # Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated # part of the block. self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2, @@ -123,11 +124,14 @@ class FixupResNet(nn.Module): skip_lo = self.skip1(x) + self.skip1_bias x = self.lrelu(self.upsampler_conv(x) + self.uc_bias) - x = F.interpolate(x, scale_factor=2, mode='nearest') + x = F.interpolate(x, scale_factor=2.0, mode='nearest') x = self.layer2(x) skip_med = self.skip2(x) + self.skip2_bias - x = F.interpolate(x, scale_factor=2, mode='nearest') - x = self.layer2(x) + + if self.upscale_applications > 1: + x = F.interpolate(x, scale_factor=2.0, mode='nearest') + x = self.layer2(x) + x = self.final_defilter(x) + self.bias2 return x, skip_med, skip_lo diff --git a/codes/models/networks.py b/codes/models/networks.py index f1f39640..2a75014d 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -35,7 +35,7 @@ def define_G(opt): interpolation_scale_factor=scale_per_step) elif which_model == 'ResGen': netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], - num_filters=opt_net['nf']) + upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) # image corruption elif which_model == 'HighToLowResNet':