From a47a5dca43cddb6f137d14bc03c2b2f6913582fe Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 5 Jul 2020 21:49:09 -0600 Subject: [PATCH] Fix pixdisc bug --- codes/models/archs/discriminator_vgg_arch.py | 60 ++++++++++++++++++++ codes/models/networks.py | 2 + codes/train.py | 2 +- 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index fe3d3288..f99edfc5 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torchvision +from models.archs.arch_util import ConvBnLelu class Discriminator_VGG_128(nn.Module): @@ -76,3 +77,62 @@ class Discriminator_VGG_128(nn.Module): out = self.linear2(fea) return out + +class Discriminator_VGG_PixLoss(nn.Module): + def __init__(self, in_nc, nf): + super(Discriminator_VGG_PixLoss, self).__init__() + # [64, 128, 128] + self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) + self.bn0_1 = nn.BatchNorm2d(nf, affine=True) + # [64, 64, 64] + self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) + self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) + self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) + self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) + # [128, 32, 32] + self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) + self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) + self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) + self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) + # [256, 16, 16] + self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) + self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) + # [512, 8, 8] + self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) + self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) + self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) + self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) + + self.reduce_1 = ConvBnLelu(nf * 8, nf * 4, bias=False) + self.pix_loss_collapse = ConvBnLelu(nf * 4, 1, bias=False) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + x = x[0] + fea = self.lrelu(self.conv0_0(x)) + fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + + fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) + fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + + fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) + fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + + fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) + fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + + fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) + fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) + + loss = self.reduce_1(fea) + loss = self.pix_loss_collapse(loss) + + # Compress all of the loss values into the batch dimension. The actual loss attached to this output will + # then know how to handle them. + return loss.view(-1, 1) + diff --git a/codes/models/networks.py b/codes/models/networks.py index 927c871a..ecb61a83 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -114,6 +114,8 @@ def define_D(opt): netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, number_skips=opt_net['number_skips'], use_bn=True, disable_passthrough=opt_net['disable_passthrough']) + elif which_model == 'discriminator_pix': + netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/train.py b/codes/train.py index 37efc401..15c525f5 100644 --- a/codes/train.py +++ b/codes/train.py @@ -33,7 +33,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_div2k_srg2/train_div2k_srg2_basis.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_rrdb.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)