Resnet discriminator overhaul

It's been a tough day figuring out WTH is going on with my discriminators.
It appears the raw FixUp discriminator can get into an "defective" state where
they stop trying to learn and just predict as close to "0" D_fake and D_real as
possible. In this state they provide no feedback to the generator and never
recover. Adding batch norm back in seems to fix this so it must be some sort
of parameterization error.. Should look into fixing this in the future.
This commit is contained in:
James Betker 2020-05-06 17:27:30 -06:00
parent 602f86bfa4
commit aa0305def9
2 changed files with 44 additions and 41 deletions

View File

@ -11,6 +11,11 @@ def conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
def conv5x5(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
padding=2, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
@ -20,27 +25,35 @@ def conv1x1(in_planes, out_planes, stride=1):
class FixupBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, use_bn=False, conv_create=conv3x3):
super(FixupBasicBlock, self).__init__()
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.bias1a = nn.Parameter(torch.zeros(1))
self.conv1 = conv3x3(inplanes, planes, stride)
self.conv1 = conv_create(inplanes, planes, stride)
self.bias1b = nn.Parameter(torch.zeros(1))
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bias2a = nn.Parameter(torch.zeros(1))
self.conv2 = conv3x3(planes, planes)
self.conv2 = conv_create(planes, planes)
self.scale = nn.Parameter(torch.ones(1))
self.bias2b = nn.Parameter(torch.zeros(1))
self.downsample = downsample
self.stride = stride
self.use_bn = use_bn
if use_bn:
self.bn1 = nn.BatchNorm2d(planes)
self.bn2 = nn.BatchNorm2d(planes)
def forward(self, x):
identity = x
out = self.conv1(x + self.bias1a)
if self.use_bn:
out = self.bn1(out)
out = self.lrelu(out + self.bias1b)
out = self.conv2(out + self.bias2a)
if self.use_bn:
out = self.bn2(out)
out = out * self.scale + self.bias2b
if self.downsample is not None:
@ -94,26 +107,23 @@ class FixupBottleneck(nn.Module):
class FixupResNet(nn.Module):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64):
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, use_bn=False):
super(FixupResNet, self).__init__()
self.num_layers = sum(layers)
self.inplanes = num_filters
self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3,
bias=False)
self.bias1 = nn.Parameter(torch.zeros(1))
self.inplanes = 3
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
self.skip1 = nn.Conv2d(num_filters + 3, num_filters, kernel_size=5, stride=1, padding=2, bias=False)
self.skip1_bias = nn.Parameter(torch.zeros(1))
self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2)
self.skip2 = nn.Conv2d(num_filters*2 + 3, num_filters*2, kernel_size=5, stride=1, padding=2, bias=False)
self.skip2_bias = nn.Parameter(torch.zeros(1))
self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
self.layer5 = self._make_layer(block, num_filters*16, layers[4], stride=2)
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator.
self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
self.inplanes = self.inplanes + 3 # Accomodate a skip connection from the generator.
self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
# SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
# Therefore, level off the filter count from this block forwards.
self.layer3 = self._make_layer(block, num_filters*8, layers[3], stride=2, use_bn=use_bn)
self.layer4 = self._make_layer(block, num_filters*8, layers[4], stride=2, use_bn=use_bn)
self.bias2 = nn.Parameter(torch.zeros(1))
reduced_img_sz = int(input_img_size / 32)
self.fc1 = nn.Linear(num_filters * 16 * reduced_img_sz * reduced_img_sz, 100)
self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
self.fc2 = nn.Linear(100, num_classes)
for m in self.modules():
@ -128,39 +138,31 @@ class FixupResNet(nn.Module):
nn.init.constant_(m.conv3.weight, 0)
if m.downsample is not None:
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
'''
elif isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)'''
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
def _make_layer(self, block, outplanes, blocks, stride=1, use_bn=False, conv_type=conv3x3):
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
layers.append(block(self.inplanes, self.inplanes))
downsample = None
if stride != 1 or self.inplanes != outplanes * block.expansion:
downsample = conv1x1(self.inplanes, outplanes * block.expansion, stride)
layers.append(block(self.inplanes, outplanes, stride, downsample, use_bn=use_bn, conv_create=conv_type))
self.inplanes = outplanes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
# This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
hi, med_skip, lo_skip = x
x, med_skip, lo_skip = x
x = self.conv1(hi)
x = self.lrelu(x + self.bias1)
x = self.layer0(x)
x = torch.cat([x, med_skip], dim=1)
x = self.layer1(x)
x = self.lrelu(self.skip1(torch.cat([x, med_skip], dim=1)) + self.skip1_bias)
x = torch.cat([x, lo_skip], dim=1)
x = self.layer2(x)
x = self.lrelu(self.skip2(torch.cat([x, lo_skip], dim=1)) + self.skip2_bias)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = x.view(x.size(0), -1)
x = self.lrelu(self.fc1(x))
@ -179,7 +181,7 @@ def fixup_resnet18(**kwargs):
def fixup_resnet34(**kwargs):
"""Constructs a Fixup-ResNet-34 model.
"""
model = FixupResNet(FixupBasicBlock, [5, 4, 3, 3, 2], **kwargs)
model = FixupResNet(FixupBasicBlock, [5, 5, 3, 3, 3], **kwargs)
return model

View File

@ -34,7 +34,8 @@ def define_G(opt):
nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
interpolation_scale_factor=scale_per_step)
elif which_model == 'ResGen':
netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
num_filters=opt_net['nf'])
# image corruption
elif which_model == 'HighToLowResNet':
@ -70,7 +71,7 @@ def define_D(opt):
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_passthrough':
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, use_bn=True)
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD