From 5d1b4caabf70fdd272006f12f603742a1bfdb07d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 12 May 2020 16:26:29 -0600 Subject: [PATCH] Allow noise to be injected at the generator inputs for resgen --- codes/models/archs/ResGen_arch.py | 14 ++++++++++++-- codes/models/networks.py | 3 ++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py index 8c2b9447..d087851f 100644 --- a/codes/models/archs/ResGen_arch.py +++ b/codes/models/archs/ResGen_arch.py @@ -59,14 +59,18 @@ class FixupBasicBlock(nn.Module): class FixupResNet(nn.Module): - def __init__(self, block, layers, upscale_applications=2, num_filters=64): + def __init__(self, block, layers, upscale_applications=2, num_filters=64, inject_noise=False): super(FixupResNet, self).__init__() self.num_layers = sum(layers) + layers[-1] # The last layer is applied twice to achieve 4x upsampling. self.inplanes = num_filters self.upscale_applications = upscale_applications + self.inject_noise = inject_noise # Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated # part of the block. - self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2, + input_planes = 3 + if inject_noise: + input_planes = 4 + self.conv1 = nn.Conv2d(input_planes, num_filters, kernel_size=5, stride=1, padding=2, bias=False) self.bias1 = nn.Parameter(torch.zeros(1)) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -119,6 +123,9 @@ class FixupResNet(nn.Module): return nn.Sequential(*layers) def forward(self, x): + if self.inject_noise: + rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype) + x = torch.cat([x, rand_feature], dim=1) x = self.conv1(x) x = self.lrelu(x + self.bias1) x = self.layer1(x) @@ -161,6 +168,9 @@ class FixupResNetV2(FixupResNet): return x def forward(self, x): + if self.inject_noise: + rand_feature = torch.randn((x.shape[0], 1) + x.shape[2:], device=x.device, dtype=x.dtype) + x = torch.cat([x, rand_feature], dim=1) x = self.conv1(x) x = self.lrelu(x + self.bias1) x = self.layer1(x) diff --git a/codes/models/networks.py b/codes/models/networks.py index 6f7c1197..9c8b4847 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -38,7 +38,8 @@ def define_G(opt): upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) elif which_model == 'ResGenV2': netG = ResGen_arch.fixup_resnet34_v2(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], - upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf']) + upscale_applications=opt_net['upscale_applications'], num_filters=opt_net['nf'], + inject_noise=opt_net['inject_noise']) # image corruption elif which_model == 'HighToLowResNet':