forked from mrq/DL-Art-School
Resnet discriminator overhaul
It's been a tough day figuring out WTH is going on with my discriminators. It appears the raw FixUp discriminator can get into a "defective" state where it stops trying to learn and just predicts D_fake and D_real as close to "0" as possible. In this state it provides no feedback to the generator and never recovers. Adding batch norm back in seems to fix this, so it must be some sort of parameterization error. Should look into fixing this in the future.
This commit is contained in:
parent
602f86bfa4
commit
aa0305def9
|
@ -11,6 +11,11 @@ def conv3x3(in_planes, out_planes, stride=1):
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
padding=1, bias=False)
|
padding=1, bias=False)
|
||||||
|
|
||||||
|
def conv5x5(in_planes, out_planes, stride=1):
|
||||||
|
"""5x5 convolution with padding"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
|
||||||
|
padding=2, bias=False)
|
||||||
|
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
"""1x1 convolution"""
|
"""1x1 convolution"""
|
||||||
|
@ -20,27 +25,35 @@ def conv1x1(in_planes, out_planes, stride=1):
|
||||||
class FixupBasicBlock(nn.Module):
|
class FixupBasicBlock(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
def __init__(self, inplanes, planes, stride=1, downsample=None, use_bn=False, conv_create=conv3x3):
|
||||||
super(FixupBasicBlock, self).__init__()
|
super(FixupBasicBlock, self).__init__()
|
||||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||||
self.bias1a = nn.Parameter(torch.zeros(1))
|
self.bias1a = nn.Parameter(torch.zeros(1))
|
||||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
self.conv1 = conv_create(inplanes, planes, stride)
|
||||||
self.bias1b = nn.Parameter(torch.zeros(1))
|
self.bias1b = nn.Parameter(torch.zeros(1))
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
self.bias2a = nn.Parameter(torch.zeros(1))
|
self.bias2a = nn.Parameter(torch.zeros(1))
|
||||||
self.conv2 = conv3x3(planes, planes)
|
self.conv2 = conv_create(planes, planes)
|
||||||
self.scale = nn.Parameter(torch.ones(1))
|
self.scale = nn.Parameter(torch.ones(1))
|
||||||
self.bias2b = nn.Parameter(torch.zeros(1))
|
self.bias2b = nn.Parameter(torch.zeros(1))
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
self.use_bn = use_bn
|
||||||
|
if use_bn:
|
||||||
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
identity = x
|
identity = x
|
||||||
|
|
||||||
out = self.conv1(x + self.bias1a)
|
out = self.conv1(x + self.bias1a)
|
||||||
|
if self.use_bn:
|
||||||
|
out = self.bn1(out)
|
||||||
out = self.lrelu(out + self.bias1b)
|
out = self.lrelu(out + self.bias1b)
|
||||||
|
|
||||||
out = self.conv2(out + self.bias2a)
|
out = self.conv2(out + self.bias2a)
|
||||||
|
if self.use_bn:
|
||||||
|
out = self.bn2(out)
|
||||||
out = out * self.scale + self.bias2b
|
out = out * self.scale + self.bias2b
|
||||||
|
|
||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
|
@ -94,26 +107,23 @@ class FixupBottleneck(nn.Module):
|
||||||
|
|
||||||
class FixupResNet(nn.Module):
|
class FixupResNet(nn.Module):
|
||||||
|
|
||||||
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64):
|
def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, use_bn=False):
|
||||||
super(FixupResNet, self).__init__()
|
super(FixupResNet, self).__init__()
|
||||||
self.num_layers = sum(layers)
|
self.num_layers = sum(layers)
|
||||||
self.inplanes = num_filters
|
self.inplanes = 3
|
||||||
self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3,
|
|
||||||
bias=False)
|
|
||||||
self.bias1 = nn.Parameter(torch.zeros(1))
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
|
self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
||||||
self.skip1 = nn.Conv2d(num_filters + 3, num_filters, kernel_size=5, stride=1, padding=2, bias=False)
|
self.inplanes = self.inplanes + 3 # Accommodate a skip connection from the generator.
|
||||||
self.skip1_bias = nn.Parameter(torch.zeros(1))
|
self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
|
||||||
self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2)
|
self.inplanes = self.inplanes + 3 # Accommodate a skip connection from the generator.
|
||||||
self.skip2 = nn.Conv2d(num_filters*2 + 3, num_filters*2, kernel_size=5, stride=1, padding=2, bias=False)
|
self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
|
||||||
self.skip2_bias = nn.Parameter(torch.zeros(1))
|
# SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
|
||||||
self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
|
# Therefore, level off the filter count from this block forwards.
|
||||||
self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
|
self.layer3 = self._make_layer(block, num_filters*8, layers[3], stride=2, use_bn=use_bn)
|
||||||
self.layer5 = self._make_layer(block, num_filters*16, layers[4], stride=2)
|
self.layer4 = self._make_layer(block, num_filters*8, layers[4], stride=2, use_bn=use_bn)
|
||||||
self.bias2 = nn.Parameter(torch.zeros(1))
|
self.bias2 = nn.Parameter(torch.zeros(1))
|
||||||
reduced_img_sz = int(input_img_size / 32)
|
reduced_img_sz = int(input_img_size / 32)
|
||||||
self.fc1 = nn.Linear(num_filters * 16 * reduced_img_sz * reduced_img_sz, 100)
|
self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
|
||||||
self.fc2 = nn.Linear(100, num_classes)
|
self.fc2 = nn.Linear(100, num_classes)
|
||||||
|
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
|
@ -128,39 +138,31 @@ class FixupResNet(nn.Module):
|
||||||
nn.init.constant_(m.conv3.weight, 0)
|
nn.init.constant_(m.conv3.weight, 0)
|
||||||
if m.downsample is not None:
|
if m.downsample is not None:
|
||||||
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
|
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
|
||||||
'''
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
nn.init.constant_(m.weight, 0)
|
|
||||||
nn.init.constant_(m.bias, 0)'''
|
|
||||||
|
|
||||||
def _make_layer(self, block, planes, blocks, stride=1):
|
|
||||||
downsample = None
|
|
||||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
||||||
downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
|
|
||||||
|
|
||||||
|
def _make_layer(self, block, outplanes, blocks, stride=1, use_bn=False, conv_type=conv3x3):
|
||||||
layers = []
|
layers = []
|
||||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
|
||||||
self.inplanes = planes * block.expansion
|
|
||||||
for _ in range(1, blocks):
|
for _ in range(1, blocks):
|
||||||
layers.append(block(self.inplanes, planes))
|
layers.append(block(self.inplanes, self.inplanes))
|
||||||
|
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != outplanes * block.expansion:
|
||||||
|
downsample = conv1x1(self.inplanes, outplanes * block.expansion, stride)
|
||||||
|
layers.append(block(self.inplanes, outplanes, stride, downsample, use_bn=use_bn, conv_create=conv_type))
|
||||||
|
self.inplanes = outplanes * block.expansion
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
|
# This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
|
||||||
hi, med_skip, lo_skip = x
|
x, med_skip, lo_skip = x
|
||||||
|
|
||||||
x = self.conv1(hi)
|
x = self.layer0(x)
|
||||||
x = self.lrelu(x + self.bias1)
|
x = torch.cat([x, med_skip], dim=1)
|
||||||
x = self.layer1(x)
|
x = self.layer1(x)
|
||||||
x = self.lrelu(self.skip1(torch.cat([x, med_skip], dim=1)) + self.skip1_bias)
|
x = torch.cat([x, lo_skip], dim=1)
|
||||||
|
|
||||||
x = self.layer2(x)
|
x = self.layer2(x)
|
||||||
x = self.lrelu(self.skip2(torch.cat([x, lo_skip], dim=1)) + self.skip2_bias)
|
|
||||||
|
|
||||||
x = self.layer3(x)
|
x = self.layer3(x)
|
||||||
x = self.layer4(x)
|
x = self.layer4(x)
|
||||||
x = self.layer5(x)
|
|
||||||
|
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
x = self.lrelu(self.fc1(x))
|
x = self.lrelu(self.fc1(x))
|
||||||
|
@ -179,7 +181,7 @@ def fixup_resnet18(**kwargs):
|
||||||
def fixup_resnet34(**kwargs):
|
def fixup_resnet34(**kwargs):
|
||||||
"""Constructs a Fixup-ResNet-34 model.
|
"""Constructs a Fixup-ResNet-34 model.
|
||||||
"""
|
"""
|
||||||
model = FixupResNet(FixupBasicBlock, [5, 4, 3, 3, 2], **kwargs)
|
model = FixupResNet(FixupBasicBlock, [5, 5, 3, 3, 3], **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,8 @@ def define_G(opt):
|
||||||
nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
|
nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
|
||||||
interpolation_scale_factor=scale_per_step)
|
interpolation_scale_factor=scale_per_step)
|
||||||
elif which_model == 'ResGen':
|
elif which_model == 'ResGen':
|
||||||
netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
|
netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'],
|
||||||
|
num_filters=opt_net['nf'])
|
||||||
|
|
||||||
# image corruption
|
# image corruption
|
||||||
elif which_model == 'HighToLowResNet':
|
elif which_model == 'HighToLowResNet':
|
||||||
|
@ -70,7 +71,7 @@ def define_D(opt):
|
||||||
elif which_model == 'discriminator_resnet':
|
elif which_model == 'discriminator_resnet':
|
||||||
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
||||||
elif which_model == 'discriminator_resnet_passthrough':
|
elif which_model == 'discriminator_resnet_passthrough':
|
||||||
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
|
netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, use_bn=True)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
Loading…
Reference in New Issue
Block a user