From 6beefa6d0c2718295890f1a8b2f8fe4da64df772 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 6 Jul 2020 11:15:52 -0600 Subject: [PATCH] PixDisc - Add two more levels of losses coming from this gen at higher resolutions --- codes/models/archs/discriminator_vgg_arch.py | 60 +++++++++++++++----- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/codes/models/archs/discriminator_vgg_arch.py b/codes/models/archs/discriminator_vgg_arch.py index 02fd156e..b5746e9b 100644 --- a/codes/models/archs/discriminator_vgg_arch.py +++ b/codes/models/archs/discriminator_vgg_arch.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import torchvision from models.archs.arch_util import ConvBnLelu +import torch.nn.functional as F class Discriminator_VGG_128(nn.Module): @@ -109,30 +110,61 @@ class Discriminator_VGG_PixLoss(nn.Module): self.reduce_1 = ConvBnLelu(nf * 8, nf * 4, bias=False) self.pix_loss_collapse = ConvBnLelu(nf * 4, 1, bias=False, bn=False, lelu=False) + # Pyramid network: upsample with residuals and produce losses at multiple resolutions. + self.up3_decimate = ConvBnLelu(nf * 8, nf * 8, kernel_size=3, bias=True, lelu=False) + self.up3_converge = ConvBnLelu(nf * 16, nf * 8, kernel_size=3, bias=False) + self.up3_proc = ConvBnLelu(nf * 8, nf * 8, bias=False) + self.up3_reduce = ConvBnLelu(nf * 8, nf * 4, bias=False) + self.up3_pix = ConvBnLelu(nf * 4, 1, bias=False, bn=False, lelu=False) + + self.up2_decimate = ConvBnLelu(nf * 8, nf * 4, kernel_size=1, bias=True, lelu=False) + self.up2_converge = ConvBnLelu(nf * 8, nf * 4, kernel_size=3, bias=False) + self.up2_proc = ConvBnLelu(nf * 4, nf * 4, bias=False) + self.up2_reduce = ConvBnLelu(nf * 4, nf * 2, bias=False) + self.up2_pix = ConvBnLelu(nf * 2, 1, bias=False, bn=False, lelu=False) + # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): x = x[0] - fea = self.lrelu(self.conv0_0(x)) - fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) + fea0 = self.lrelu(self.conv0_0(x)) + fea0 = self.lrelu(self.bn0_1(self.conv0_1(fea0))) - fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) - fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) + fea1 = self.lrelu(self.bn1_0(self.conv1_0(fea0))) + fea1 = self.lrelu(self.bn1_1(self.conv1_1(fea1))) - fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) - fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) + fea2 = self.lrelu(self.bn2_0(self.conv2_0(fea1))) + fea2 = self.lrelu(self.bn2_1(self.conv2_1(fea2))) - fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) - fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) + fea3 = self.lrelu(self.bn3_0(self.conv3_0(fea2))) + fea3 = self.lrelu(self.bn3_1(self.conv3_1(fea3))) - fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) - fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) - - loss = self.reduce_1(fea) - loss = self.pix_loss_collapse(loss) + fea4 = self.lrelu(self.bn4_0(self.conv4_0(fea3))) + fea4 = self.lrelu(self.bn4_1(self.conv4_1(fea4))) + loss = self.reduce_1(fea4) # Compress all of the loss values into the batch dimension. The actual loss attached to this output will # then know how to handle them. - return loss.view(-1, 1) + loss = self.pix_loss_collapse(loss).view(-1, 1) + + # And the pyramid network! + dec3 = self.up3_decimate(F.interpolate(fea4, scale_factor=2, mode="nearest")) + dec3 = torch.cat([dec3, fea3], dim=1) + dec3 = self.up3_converge(dec3) + dec3 = self.up3_proc(dec3) + loss3 = self.up3_reduce(dec3) + loss3 = self.up3_pix(loss3).view(-1, 1) + + dec2 = self.up2_decimate(F.interpolate(dec3, scale_factor=2, mode="nearest")) + dec2 = torch.cat([dec2, fea2], dim=1) + dec2 = self.up2_converge(dec2) + dec2 = self.up2_proc(dec2) + loss2 = self.up2_reduce(dec2) + loss2 = self.up2_pix(loss2).view(-1, 1) + + # "Weight" all losses the same by repeating the LR losses to the HR dim. + combined_losses = torch.cat([loss.repeat((16, 1)), loss3.repeat((4, 1)), loss2]) + + return combined_losses.view(-1, 1)