forked from mrq/DL-Art-School
Implement unet disc
The latest discriminator architecture was already pretty much a unet. This one makes that official and uses shared layers. It also upsamples one additional time and throws out the lowest upsampling result. The intent is to delete the old vgg pixdisc, but I'll keep it around for a bit since I'm still trying out a few models with it.
This commit is contained in:
parent
812c684f7d
commit
0b7193392f
@@ -339,15 +339,11 @@ class SRGANModel(BaseModel):
                 l_d_fake_scaled.backward()
             if self.opt['train']['gan_type'] == 'pixgan':
                 # randomly determine portions of the image to swap to keep the discriminator honest.
-
-                # We're making some assumptions about the underlying pixel-discriminator here. This is a
-                # necessary evil for now, but if this turns out well we might want to make this configurable.
-                PIXDISC_CHANNELS = 3
-                PIXDISC_OUTPUT_REDUCTION = 8
-                disc_output_shape = (var_ref[0].shape[0], PIXDISC_CHANNELS, var_ref[0].shape[2] // PIXDISC_OUTPUT_REDUCTION, var_ref[0].shape[3] // PIXDISC_OUTPUT_REDUCTION)
+                pixdisc_channels, pixdisc_output_reduction = self.netD.pixgan_parameters()
+                disc_output_shape = (var_ref[0].shape[0], pixdisc_channels, var_ref[0].shape[2] // pixdisc_output_reduction, var_ref[0].shape[3] // pixdisc_output_reduction)
                 b, _, w, h = var_ref[0].shape
-                real = torch.ones((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
-                fake = torch.zeros((b, PIXDISC_CHANNELS, w, h), device=var_ref[0].device)
+                real = torch.ones((b, pixdisc_channels, w, h), device=var_ref[0].device)
+                fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref[0].device)
                 SWAP_MAX_DIM = w // 4
                 SWAP_MIN_DIM = 16
                 assert SWAP_MAX_DIM > 0
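The point of this hunk is that the trainer no longer hardcodes the pixel discriminator's output geometry; it queries the discriminator through the new pixgan_parameters() hook, so any discriminator exposing the hook can plug in. A minimal sketch of that contract (TinyPixDisc and everything inside it are hypothetical, not part of this commit):

import torch
import torch.nn as nn

# TinyPixDisc only illustrates the pixgan_parameters() contract:
# return (output_channels, spatial_reduction) so the trainer can size its targets.
class TinyPixDisc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, kernel_size=4, stride=4)  # 4x spatial reduction

    def forward(self, x):
        return self.net(x)

    def pixgan_parameters(self):
        return 3, 4  # (channels, output reduction)

netD = TinyPixDisc()
var_ref = torch.randn(2, 3, 64, 64)
pixdisc_channels, pixdisc_output_reduction = netD.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels,
                     var_ref.shape[2] // pixdisc_output_reduction,
                     var_ref.shape[3] // pixdisc_output_reduction)
assert netD(var_ref).shape == disc_output_shape  # (2, 3, 16, 16)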
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torchvision
-from models.archs.arch_util import ConvBnLelu, ConvGnLelu
+from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock
 import torch.nn.functional as F
 
 
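ExpansionBlock is imported from arch_util but its definition is not part of this diff. From its use below (each up-stage takes a decoder feature and an encoder skip with the same channel count, and produces half the channels at twice the resolution), it behaves like a standard U-Net decoder step. A rough sketch under those assumptions, not the actual arch_util implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed behavior only: upsample 2x, fuse the encoder skip, halve the channels.
class ExpansionBlockSketch(nn.Module):
    def __init__(self, in_filters):
        super().__init__()
        out = in_filters // 2
        self.decimate = nn.Conv2d(in_filters, out, 1)            # shrink upsampled features
        self.process_skip = nn.Conv2d(in_filters, out, 3, padding=1)
        self.reduce = nn.Conv2d(out * 2, out, 3, padding=1)      # fuse and reduce

    def forward(self, x, skip):
        x = F.interpolate(x, scale_factor=2, mode='nearest')     # 2x spatial upsample
        x = self.decimate(x)
        skip = self.process_skip(skip)
        return self.reduce(torch.cat([x, skip], dim=1))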
@@ -169,3 +169,74 @@ class Discriminator_VGG_PixLoss(nn.Module):
         combined_losses = torch.cat([loss, loss3, loss2], dim=1)
         return combined_losses.view(-1, 1)
 
+    def pixgan_parameters(self):
+        return 3, 8
+
+
+class Discriminator_UNet(nn.Module):
+    def __init__(self, in_nc, nf):
+        super(Discriminator_UNet, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, gn=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.up1 = ExpansionBlock(nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+
+        self.up2 = ExpansionBlock(nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+        self.up3 = ExpansionBlock(nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf, nf, bias=False)
+        self.collapse3 = ConvGnLelu(nf, 1, bias=True, norm=False, activation=False)
+
+    def forward(self, x, flatten=True):
+        x = x[0]
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        fea4 = self.conv4_0(fea3)
+        fea4 = self.conv4_1(fea4)
+
+        # And the pyramid network!
+        u1 = self.up1(fea4, fea3)
+        loss1 = self.collapse1(self.proc1(u1))
+        u2 = self.up2(u1, fea2)
+        loss2 = self.collapse2(self.proc2(u2))
+        u3 = self.up3(u2, fea1)
+        loss3 = self.collapse3(self.proc3(u3))
+        res = loss3.shape[2:]
+
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
+                                     F.interpolate(loss2, scale_factor=2),
+                                     F.interpolate(loss3, scale_factor=1)], dim=1)
+        return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 4
+
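The values (3, 4) returned here fall out of the forward pass above: the encoder halves the resolution five times, the three decoder stages bring the finest map back to input/4, and the two coarser maps are interpolated (4x and 2x) onto that same grid before being stacked as three one-channel maps. A quick shape check, assuming the ConvGnLelu defaults pad 3x3 convs to preserve size (this test is illustrative, not part of the commit):

import torch
from models.archs import SRGAN_arch  # module path as referenced in define_D below

netD = SRGAN_arch.Discriminator_UNet(in_nc=3, nf=64)
x = torch.randn(2, 3, 128, 128)
out = netD([x])                                   # forward() unwraps x[0]
channels, reduction = netD.pixgan_parameters()    # (3, 4)
per_image = channels * (128 // reduction) ** 2    # 3 * 32 * 32 logits per image
assert out.shape == (2 * per_image, 1)            # losses flattened into the batch dim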
@@ -111,6 +111,8 @@ def define_D(opt):
                                                disable_passthrough=opt_net['disable_passthrough'])
     elif which_model == 'discriminator_pix':
         netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+    elif which_model == "discriminator_unet":
+        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
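To train against the new discriminator, the options file selects it by name through define_D. A hypothetical fragment of the opt_net dict it reads (the exact 'which_model_D' key and the placeholder values are assumptions, not taken from this commit):

# Hypothetical opt_net fragment consumed by define_D; values are placeholders.
opt_net = {
    'which_model_D': 'discriminator_unet',  # exact option key is an assumption
    'in_nc': 3,    # RGB input
    'nf': 64,      # base filter count, doubled at each downsampling stage
}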