Implement unet disc

The latest discriminator architecture was already pretty much a U-Net. This
one makes that official and uses shared layers. It also upsamples one additional
time and discards the lowest-resolution upsampling result.

The intent is to delete the old vgg pixdisc, but I'll keep it around for a bit since
I'm still trying out a few models with it.
James Betker 2020-07-10 16:16:03 -06:00
parent 812c684f7d
commit 0b7193392f
3 changed files with 78 additions and 9 deletions


@@ -339,15 +339,11 @@ class SRGANModel(BaseModel):
                 l_d_fake_scaled.backward()
             if self.opt['train']['gan_type'] == 'pixgan':
                 # randomly determine portions of the image to swap to keep the discriminator honest.
-                # We're making some assumptions about the underlying pixel-discriminator here. This is a
-                # necessary evil for now, but if this turns out well we might want to make this configurable.
-                PIXDISC_CHANNELS = 3
-                PIXDISC_OUTPUT_REDUCTION = 8
-                disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
+                pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters()
+                disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
                 b, _, w, h = var_ref[0].shape
-                real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
-                fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
+                fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
                 SWAP_MAX_DIM = w // 4
                 SWAP_MIN_DIM = 16
                 assert SWAP_MAX_DIM > 0
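The refactor above matters because the new U-Net discriminator reports (3, 4) rather than the old hard-coded (3, 8). A standalone sketch of the shape math (not part of the commit; the batch dimensions are assumed for illustration):

    # Hypothetical values for illustration only: pixgan_parameters() for the
    # new Discriminator_UNet returns (3, 4), and the reference batch shape is assumed.
    pixdisc_channels, pixdisc_output_reduction = 3, 4
    b, c, w, h = 16, 3, 128, 128
    disc_output_shape = (b, pixdisc_channels,
                         w // pixdisc_output_reduction,
                         h // pixdisc_output_reduction)
    print(disc_output_shape)  # (16, 3, 32, 32): one prediction per patch, per pyramid level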


@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock
 import torch.nn.functional as F
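The only interface change this file needs is the `ExpansionBlock` import. Its contract can be inferred from how `Discriminator_UNet` uses it below; here is a rough, self-contained stand-in under those assumptions (the real implementation lives in `arch_util` and may differ in detail):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ExpansionBlockSketch(nn.Module):
        # Assumed contract, inferred from the usage in Discriminator_UNet:
        #   forward(bottom: (B, C, H, W), skip: (B, C, 2H, 2W)) -> (B, C//2, 2H, 2W)
        def __init__(self, filters, block=None):  # `block` accepted for signature parity; ignored here
            super().__init__()
            # Halve the channels of the upsampled bottom feature...
            self.decimate = nn.Conv2d(filters, filters // 2, kernel_size=1, bias=False)
            # ...then fuse it with the same-resolution encoder skip connection.
            self.conjoin = nn.Conv2d(filters + filters // 2, filters // 2, kernel_size=3, padding=1, bias=False)

        def forward(self, bottom, skip):
            up = self.decimate(F.interpolate(bottom, scale_factor=2, mode='nearest'))
            return self.conjoin(torch.cat([up, skip], dim=1))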
@@ -169,3 +169,74 @@ class Discriminator_VGG_PixLoss(nn.Module):
         combined_losses = torch.cat([loss, loss3, loss2], dim=1)
         return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 8
+
+
+class Discriminator_UNet(nn.Module):
+    def __init__(self, in_nc, nf):
+        super(Discriminator_UNet, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, gn=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+        self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+        self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf, nf, bias=False)
+        self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False)
+
+    def forward(self, x, flatten=True):
+        x = x[0]
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
+        # And the pyramid network!
+        u1 = self.up1(fea4, fea3)
+        loss1 = self.collapse1(self.proc1(u1))
+        u2 = self.up2(u1, fea2)
+        loss2 = self.collapse2(self.proc2(u2))
+        u3 = self.up3(u2, fea1)
+        loss3 = self.collapse3(self.proc3(u3))
+        res = loss3.shape[2:]
+
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
+                                     F.interpolate(loss2, scale_factor=2),
+                                     F.interpolate(loss3, scale_factor=1)], dim=1)
+        return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 4
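A quick sanity check of the shapes (not part of the commit; the import path is inferred from the `SRGAN_arch` references in `define_D` below, and the `arch_util` blocks are assumed to preserve spatial size at stride 1):

    import torch
    from models.archs.SRGAN_arch import Discriminator_UNet  # module path assumed

    netD = Discriminator_UNet(in_nc=3, nf=64)
    out = netD([torch.randn(1, 3, 128, 128)])
    # Five stride-2 convs take 128 -> 4; three ExpansionBlocks bring that back
    # up to 32, so each collapsed map is 32x32 = 128/4 -- which is why
    # pixgan_parameters() now returns (3, 4). The three maps are interpolated
    # to the common 32x32 resolution, stacked on dim 1, and flattened:
    # (1, 3, 32, 32) -> (1*3*32*32, 1).
    print(out.shape)  # torch.Size([3072, 1])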


@@ -111,6 +111,8 @@ def define_D(opt):
                                               disable_passthrough=opt_net['disable_passthrough'])
     elif which_model == 'discriminator_pix':
         netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+    elif which_model == "discriminator_unet":
+        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
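Selecting the new discriminator is then just a matter of naming it in the options. A hypothetical minimal example (the `network_D` and `which_model_D` key names are assumptions based on the `opt_net[...]` lookups above):

    # Hypothetical options dict; only the keys read by the branch above.
    opt = {'network_D': {'which_model_D': 'discriminator_unet',
                         'in_nc': 3,
                         'nf': 64}}
    netD = define_D(opt)  # -> SRGAN_arch.Discriminator_UNet(in_nc=3, nf=64)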