Allow resgen to have a conditional number of upsamples applied to it

This commit is contained in:
James Betker 2020-05-10 10:48:37 -06:00
parent 8969a3ce70
commit ef48e819aa
2 changed files with 9 additions and 5 deletions

View File

@ -59,10 +59,11 @@ class FixupBasicBlock(nn.Module):
class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64):
def __init__(self, block, layers, upscale_applications=2, num_filters=64):
super(FixupResNet, self).__init__()
self.num_layers = sum(layers) + layers[-1] # The last layer is applied twice to achieve 4x upsampling.
self.inplanes = num_filters
self.upscale_applications = upscale_applications
# Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated
# part of the block.
self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2,
@ -123,11 +124,14 @@ class FixupResNet(nn.Module):
skip_lo = self.skip1(x) + self.skip1_bias
x = self.lrelu(self.upsampler_conv(x) + self.uc_bias)
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
x = self.layer2(x)
skip_med = self.skip2(x) + self.skip2_bias
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.layer2(x)
if self.upscale_applications > 1:
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
x = self.layer2(x)
x = self.final_defilter(x) + self.bias2
return x, skip_med, skip_lo

View File

@ -35,7 +35,7 @@ def define_G(opt):
interpolation_scale_factor=scale_per_step)
elif which_model == 'ResGen':
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
num_filters=opt_net['nf'])
upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'])
# image corruption
elif which_model == 'HighToLowResNet':