class ResidualBlockGN(nn.Module):
    '''Residual block with GroupNorm
    ---Conv-GN-LeakyReLU-Conv-GN-+-
     |__________________________|

    Both convolutions are 3x3, stride 1, padding 1 and map nf -> nf channels,
    so the residual branch keeps the input shape and can be added back to it.

    Args:
        nf (int): number of input/output channels of both convolutions.
        num_groups (int): number of groups used by both GroupNorm layers.
            nn.GroupNorm requires nf to be divisible by num_groups.
            Default: 8 (the value that was previously hard-coded).
    '''

    def __init__(self, nf=64, num_groups=8):
        super(ResidualBlockGN, self).__init__()
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        # NOTE: the attributes are called BN1/BN2 but hold GroupNorm layers.
        # The names are kept as-is because they are part of the module's
        # parameter keys; renaming them would break existing checkpoints.
        self.BN1 = nn.GroupNorm(num_groups, nf)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.BN2 = nn.GroupNorm(num_groups, nf)

        # initialization: only the convolutions are passed to the helper;
        # the GroupNorm layers keep their default initialization.
        # (0.1 is presumably a weight scale - see initialize_weights.)
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        '''Return x + GN(Conv(LeakyReLU(GN(Conv(x))))).'''
        identity = x
        out = self.lrelu(self.BN1(self.conv1(x)))
        # No activation after the second GroupNorm: the normalized output is
        # added directly to the skip connection.
        out = self.BN2(self.conv2(out))
        return identity + out
= nn.Sequential(*[ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False, norm=True, activation=True)]) + self.top_proc = nn.Sequential(*[ResidualBlockGN(nf), + ResidualBlockGN(nf), + ResidualBlockGN(nf)]) self.pyramid = Pyramid(nf, depth=3, processing_convs_per_layer=2, processing_at_point=2, scale_per_level=1.5, norm=True, return_outlevels=False) - self.bottom_proc = nn.Sequential(*[ + self.bottom_proc = nn.Sequential(*[ResidualBlockGN(nf), + ResidualBlockGN(nf), + ResidualBlockGN(nf), + ResidualBlockGN(nf), ConvGnLelu(nf, nf // 2, kernel_size=1, activation=True, norm=True, bias=True), ConvGnLelu(nf // 2, nf // 4, kernel_size=1, activation=True, norm=True, bias=True), ConvGnLelu(nf // 4, 1, activation=False, norm=False, bias=True)])