diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py index ea597f64..8c2b9447 100644 --- a/codes/models/archs/ResGen_arch.py +++ b/codes/models/archs/ResGen_arch.py @@ -78,19 +78,20 @@ class FixupResNet(nn.Module): # convs which are intended to repair artifacts caused by 2x interpolation. # This core layer should by itself accomplish 2x super-resolution. We use it in repeat to do the # requested SR. - nf2 = int(num_filters/4) + self.nf2 = int(num_filters/4) # This part isn't repeated. It de-filters the output from the previous step to fit the filter size used in the # upsampler-conv. - self.upsampler_conv = nn.Conv2d(num_filters, nf2, kernel_size=3, stride=1, padding=1, bias=False) + self.upsampler_conv = nn.Conv2d(num_filters, self.nf2, kernel_size=3, stride=1, padding=1, bias=False) self.uc_bias = nn.Parameter(torch.zeros(1)) - self.inplanes = nf2 + self.inplanes = self.nf2 - # This is the repeated part. - self.layer2 = self._make_layer(block, int(nf2), layers[1], stride=1, conv_type=conv5x5) - self.skip2 = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=False) - self.skip2_bias = nn.Parameter(torch.zeros(1)) + if layers[1] > 0: + # This is the repeated part. + self.layer2 = self._make_layer(block, int(self.nf2), layers[1], stride=1, conv_type=conv5x5) + self.skip2 = nn.Conv2d(self.nf2, 3, kernel_size=5, stride=1, padding=2, bias=False) + self.skip2_bias = nn.Parameter(torch.zeros(1)) - self.final_defilter = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=True) + self.final_defilter = nn.Conv2d(self.nf2, 3, kernel_size=5, stride=1, padding=2, bias=True) self.bias2 = nn.Parameter(torch.zeros(1)) for m in self.modules(): @@ -124,9 +125,12 @@ class FixupResNet(nn.Module): skip_lo = self.skip1(x) + self.skip1_bias x = self.lrelu(self.upsampler_conv(x) + self.uc_bias) - x = F.interpolate(x, scale_factor=2.0, mode='nearest') - x = self.layer2(x) - skip_med = self.skip2(x) + self.skip2_bias + if self.upscale_applications > 0: + x = F.interpolate(x, scale_factor=2.0, mode='nearest') + x = self.layer2(x) + skip_med = self.skip2(x) + self.skip2_bias + else: + skip_med = skip_lo if self.upscale_applications > 1: x = F.interpolate(x, scale_factor=2.0, mode='nearest') @@ -135,11 +139,59 @@ class FixupResNet(nn.Module): x = self.final_defilter(x) + self.bias2 return x, skip_med, skip_lo +class FixupResNetV2(FixupResNet): + def __init__(self, **kwargs): + super(FixupResNetV2, self).__init__(**kwargs) + # Use one unified filter-to-image stack, not the previous skip stacks. + self.skip1 = None + self.skip1_bias = None + self.skip2 = None + self.skip2_bias = None + # The new filter-to-image stack will be 2 conv layers deep, not 1. + self.final_process = nn.Conv2d(self.nf2, self.nf2, kernel_size=5, stride=1, padding=2, bias=True) + self.bias2 = nn.Parameter(torch.zeros(1)) + self.fp_bn = nn.BatchNorm2d(self.nf2) + self.final_defilter = nn.Conv2d(self.nf2, 3, kernel_size=3, stride=1, padding=1, bias=True) + self.bias3 = nn.Parameter(torch.zeros(1)) + + def filter_to_image(self, filter): + x = self.final_process(filter) + self.bias2 + x = self.lrelu(self.fp_bn(x)) + x = self.final_defilter(x) + self.bias3 + return x + + def forward(self, x): + x = self.conv1(x) + x = self.lrelu(x + self.bias1) + x = self.layer1(x) + x = self.lrelu(self.upsampler_conv(x) + self.uc_bias) + + skip_lo = self.filter_to_image(x) + if self.upscale_applications > 0: + x = F.interpolate(x, scale_factor=2.0, mode='nearest') + x = self.layer2(x) + skip_med = self.filter_to_image(x) + + if self.upscale_applications > 1: + x = F.interpolate(x, scale_factor=2.0, mode='nearest') + x = self.layer2(x) + + x = self.filter_to_image(x) + return x, skip_med, skip_lo + def fixup_resnet34(nb_denoiser=20, nb_upsampler=10, **kwargs): """Constructs a Fixup-ResNet-34 model. """ model = FixupResNet(FixupBasicBlock, [nb_denoiser, nb_upsampler], **kwargs) return model +def fixup_resnet34_v2(nb_denoiser=20, nb_upsampler=10, **kwargs): + """Constructs a Fixup-ResNet-34 model. + """ + kwargs['block'] = FixupBasicBlock + kwargs['layers'] = [nb_denoiser, nb_upsampler] + model = FixupResNetV2(**kwargs) + return model -__all__ = ['FixupResNet', 'fixup_resnet34'] \ No newline at end of file + +__all__ = ['FixupResNet', 'fixup_resnet34', 'fixup_resnet34_v2'] \ No newline at end of file diff --git a/codes/models/networks.py b/codes/models/networks.py index 2a75014d..6f7c1197 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -36,6 +36,9 @@ def define_G(opt): elif which_model == 'ResGen': netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) + elif which_model == 'ResGenV2': + netG = ResGen_arch.fixup_resnet34_v2(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], + upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) # image corruption elif which_model == 'HighToLowResNet':