From 9a1c3241f5fd8737ef523d2ca28317c11e99c051 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 6 Jul 2020 20:59:59 -0600
Subject: [PATCH] Switch discriminator to groupnorm

---
 codes/models/archs/arch_util.py              | 36 ++++++++++++++++
 codes/models/archs/discriminator_vgg_arch.py | 44 ++++++++++----------
 2 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py
index 3e97b32e..afb99703 100644
--- a/codes/models/archs/arch_util.py
+++ b/codes/models/archs/arch_util.py
@@ -283,5 +283,41 @@ class ConvBnLelu(nn.Module):
             x = self.bn(x)
         if self.lelu:
             return self.lelu(x)
+        else:
+            return x
+
+
+''' Convenience class with Conv->GroupNorm->LeakyReLU. Includes weight initialization and auto-padding for standard
+    kernel sizes. '''
+class ConvGnLelu(nn.Module):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, gn=True, bias=True, num_groups=8):
+        super(ConvGnLelu, self).__init__()
+        padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
+        assert kernel_size in padding_map.keys()
+        self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size], bias=bias)
+        if gn:
+            self.gn = nn.GroupNorm(num_groups, filters_out)
+        else:
+            self.gn = None
+        if lelu:
+            self.lelu = nn.LeakyReLU(negative_slope=.1)
+        else:
+            self.lelu = None
+
+        # Init params.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, a=.1, mode='fan_out',
+                                        nonlinearity='leaky_relu' if self.lelu else 'linear')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.gn:
+            x = self.gn(x)
+        if self.lelu:
+            return self.lelu(x)
         else:
             return x
\ No newline at end of file
diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py
index b5746e9b..4387cf89 100644
--- a/codes/models/archs/discriminator_vgg_arch.py
+++ b/codes/models/archs/discriminator_vgg_arch.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
-from models.archs.arch_util import ConvBnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu
 import torch.nn.functional as F
 
 
@@ -85,43 +85,43 @@ class Discriminator_VGG_PixLoss(nn.Module):
         # [64, 128, 128]
         self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
         self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
-        self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
+        self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
         # [64, 64, 64]
         self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
-        self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
         self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
-        self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
+        self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
         # [128, 32, 32]
         self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
-        self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
         self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
-        self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
+        self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
         # [256, 16, 16]
         self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
-        self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
+        self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
         self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
-        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
+        self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
         # [512, 8, 8]
         self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
-        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
+        self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
         self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
-        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
+        self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
 
-        self.reduce_1 = ConvBnLelu(nf * 8, nf * 4, bias=False)
-        self.pix_loss_collapse = ConvBnLelu(nf * 4, 1, bias=False, bn=False, lelu=False)
+        self.reduce_1 = ConvGnLelu(nf * 8, nf * 4, bias=False)
+        self.pix_loss_collapse = ConvGnLelu(nf * 4, 1, bias=False, gn=False, lelu=False)
 
         # Pyramid network: upsample with residuals and produce losses at multiple resolutions.
-        self.up3_decimate = ConvBnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, lelu=False)
-        self.up3_converge = ConvBnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
-        self.up3_proc = ConvBnLelu(nf * 8, nf * 8, bias=False)
-        self.up3_reduce = ConvBnLelu(nf * 8, nf * 4, bias=False)
-        self.up3_pix = ConvBnLelu(nf * 4, 1, bias=False, bn=False, lelu=False)
+        self.up3_decimate = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, lelu=False)
+        self.up3_converge = ConvGnLelu(nf * 16, nf * 8, kernel_size=3, bias=False)
+        self.up3_proc = ConvGnLelu(nf * 8, nf * 8, bias=False)
+        self.up3_reduce = ConvGnLelu(nf * 8, nf * 4, bias=False)
+        self.up3_pix = ConvGnLelu(nf * 4, 1, bias=False, gn=False, lelu=False)
 
-        self.up2_decimate = ConvBnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, lelu=False)
-        self.up2_converge = ConvBnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
-        self.up2_proc = ConvBnLelu(nf * 4, nf * 4, bias=False)
-        self.up2_reduce = ConvBnLelu(nf * 4, nf * 2, bias=False)
-        self.up2_pix = ConvBnLelu(nf * 2, 1, bias=False, bn=False, lelu=False)
+        self.up2_decimate = ConvGnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, lelu=False)
+        self.up2_converge = ConvGnLelu(nf * 8, nf * 4, kernel_size=3, bias=False)
+        self.up2_proc = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.up2_reduce = ConvGnLelu(nf * 4, nf * 2, bias=False)
+        self.up2_pix = ConvGnLelu(nf * 2, 1, bias=False, gn=False, lelu=False)
 
         # activation function
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
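
A minimal, hypothetical smoke test for the new ConvGnLelu block (not part of the patch). It assumes the repo's codes/ directory is on PYTHONPATH so that models.archs.arch_util is importable, and it relies on nn.GroupNorm's requirement that the channel count be divisible by num_groups (8 by default here):

    # Hypothetical usage sketch; the import path assumes codes/ is on PYTHONPATH.
    import torch
    from models.archs.arch_util import ConvGnLelu

    # filters_out=128 is divisible by num_groups=8, as nn.GroupNorm requires.
    block = ConvGnLelu(64, 128, kernel_size=3, stride=2, lelu=True, gn=True, bias=False)
    x = torch.randn(4, 64, 32, 32)
    y = block(x)      # Conv2d(stride=2, pad=1) -> GroupNorm(8) -> LeakyReLU(0.1)
    print(y.shape)    # torch.Size([4, 128, 16, 16])

Unlike the nn.BatchNorm2d layers it replaces in Discriminator_VGG_PixLoss, GroupNorm computes statistics per sample over channel groups, so the discriminator's normalization no longer depends on batch size. The attribute names (bn0_1, bn1_0, ...) are kept, and forward() is left untouched by this patch.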