forked from mrq/DL-Art-School
Add new referencing discriminator
Also extend the way losses work so that you can pass parameters into the discriminator from the config file
This commit is contained in:
parent
9e5aa166de
commit
313424d7b5
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
|
||||||
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
|
from models.archs.arch_util import ConvBnLelu, ConvGnLelu, ExpansionBlock, ConvGnSilu
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from models.archs.SwitchedResidualGenerator_arch import gather_2d
|
||||||
|
|
||||||
|
|
||||||
class Discriminator_VGG_128(nn.Module):
|
class Discriminator_VGG_128(nn.Module):
|
||||||
|
@ -411,4 +411,97 @@ class Discriminator_UNet_FeaOut(nn.Module):
|
||||||
return combined_losses.view(-1, 1)
|
return combined_losses.view(-1, 1)
|
||||||
|
|
||||||
def pixgan_parameters(self):
|
def pixgan_parameters(self):
|
||||||
return 1, 4
|
return 1, 4
|
||||||
|
|
||||||
|
|
||||||
|
class Vgg128GnHead(nn.Module):
|
||||||
|
def __init__(self, in_nc, nf, depth=5):
|
||||||
|
super(Vgg128GnHead, self).__init__()
|
||||||
|
assert depth == 4 or depth == 5 # Nothing stopping others from being implemented, just not done yet.
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
# [64, 128, 128]
|
||||||
|
self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
|
||||||
|
self.bn0_1 = nn.GroupNorm(8, nf, affine=True)
|
||||||
|
# [64, 64, 64]
|
||||||
|
self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
|
||||||
|
self.bn1_0 = nn.GroupNorm(8, nf * 2, affine=True)
|
||||||
|
self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
|
||||||
|
self.bn1_1 = nn.GroupNorm(8, nf * 2, affine=True)
|
||||||
|
# [128, 32, 32]
|
||||||
|
self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
|
||||||
|
self.bn2_0 = nn.GroupNorm(8, nf * 4, affine=True)
|
||||||
|
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
|
||||||
|
self.bn2_1 = nn.GroupNorm(8, nf * 4, affine=True)
|
||||||
|
# [256, 16, 16]
|
||||||
|
self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn3_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||||
|
self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn3_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||||
|
if depth > 4:
|
||||||
|
# [512, 8, 8]
|
||||||
|
self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
|
||||||
|
self.bn4_0 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||||
|
self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
|
||||||
|
self.bn4_1 = nn.GroupNorm(8, nf * 8, affine=True)
|
||||||
|
|
||||||
|
# activation function
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.lrelu(self.conv0_0(x))
|
||||||
|
fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
|
||||||
|
|
||||||
|
fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
|
||||||
|
|
||||||
|
if self.depth > 4:
|
||||||
|
fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
|
||||||
|
fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
|
||||||
|
return fea
|
||||||
|
|
||||||
|
|
||||||
|
class RefDiscriminatorVgg128(nn.Module):
|
||||||
|
# input_img_factor = multiplier to support images over 128x128. Only certain factors are supported.
|
||||||
|
def __init__(self, in_nc, nf, input_img_factor=1):
|
||||||
|
super(RefDiscriminatorVgg128, self).__init__()
|
||||||
|
|
||||||
|
# activation function
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
self.feature_head = Vgg128GnHead(in_nc, nf)
|
||||||
|
self.ref_head = Vgg128GnHead(in_nc+1, nf, depth=4)
|
||||||
|
final_nf = nf * 8
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(int(final_nf * 4 * input_img_factor * 4 * input_img_factor), 512)
|
||||||
|
self.ref_linear = nn.Linear(nf * 8, 128)
|
||||||
|
|
||||||
|
self.output_linears = nn.Sequential(
|
||||||
|
nn.Linear(128+512, 512),
|
||||||
|
self.lrelu,
|
||||||
|
nn.Linear(512, 256),
|
||||||
|
self.lrelu,
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
self.lrelu,
|
||||||
|
nn.Linear(128, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, ref, ref_center_point):
|
||||||
|
ref = self.ref_head(ref)
|
||||||
|
ref_center_point = ref_center_point // 16
|
||||||
|
ref_vector = gather_2d(ref, ref_center_point)
|
||||||
|
ref_vector = self.ref_linear(ref_vector)
|
||||||
|
|
||||||
|
fea = self.feature_head(x)
|
||||||
|
fea = fea.contiguous().view(fea.size(0), -1)
|
||||||
|
fea = self.lrelu(self.linear1(fea))
|
||||||
|
|
||||||
|
out = self.output_linears(torch.cat([fea, ref_vector], dim=1))
|
||||||
|
return out
|
|
@ -100,6 +100,8 @@ def define_D_net(opt_net, img_sz=None, wrap=False):
|
||||||
final_temperature_step=opt_net['final_temperature_step'])
|
final_temperature_step=opt_net['final_temperature_step'])
|
||||||
elif which_model == "cross_compare_vgg128":
|
elif which_model == "cross_compare_vgg128":
|
||||||
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
|
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
|
||||||
|
elif which_model == "discriminator_refvgg":
|
||||||
|
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
|
||||||
return netD
|
return netD
|
||||||
|
|
|
@ -20,6 +20,15 @@ def create_generator_loss(opt_loss, env):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
|
||||||
|
def extract_params_from_state(params, state):
|
||||||
|
if isinstance(params, list) or isinstance(params, tuple):
|
||||||
|
p = [state[r] for r in params]
|
||||||
|
else:
|
||||||
|
p = [state[params]]
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
class ConfigurableLoss(nn.Module):
|
class ConfigurableLoss(nn.Module):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(ConfigurableLoss, self).__init__()
|
super(ConfigurableLoss, self).__init__()
|
||||||
|
@ -99,17 +108,15 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea', 'crossgan', 'crossgan_lrref']:
|
fake = extract_params_from_state(self.opt['fake'], state)
|
||||||
if self.opt['gan_type'] == 'crossgan':
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
|
||||||
pred_g_fake = netD(state[self.opt['fake']], state['lq_fullsize_ref'])
|
pred_g_fake = netD(*fake)
|
||||||
elif self.opt['gan_type'] == 'crossgan_lrref':
|
|
||||||
pred_g_fake = netD(state[self.opt['fake']], state['lq'])
|
|
||||||
else:
|
|
||||||
pred_g_fake = netD(state[self.opt['fake']])
|
|
||||||
return self.criterion(pred_g_fake, True)
|
return self.criterion(pred_g_fake, True)
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
pred_d_real = netD(state[self.opt['real']]).detach()
|
real = extract_params_from_state(self.opt['real'], state)
|
||||||
pred_g_fake = netD(state[self.opt['fake']])
|
real = [r.detach() for r in real]
|
||||||
|
pred_d_real = netD(*real).detach()
|
||||||
|
pred_g_fake = netD(*fake)
|
||||||
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||||
else:
|
else:
|
||||||
|
@ -124,34 +131,19 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
self.metrics = []
|
self.metrics = []
|
||||||
|
real = extract_params_from_state(self.opt['real'], state)
|
||||||
|
fake = extract_params_from_state(self.opt['fake'], state)
|
||||||
|
fake = [f.detach() for f in fake]
|
||||||
|
d_real = net(*real)
|
||||||
|
d_fake = net(*fake)
|
||||||
|
|
||||||
if self.opt['gan_type'] == 'crossgan':
|
|
||||||
d_real = net(state[self.opt['real']], state['lq_fullsize_ref'])
|
|
||||||
d_fake = net(state[self.opt['fake']].detach(), state['lq_fullsize_ref'])
|
|
||||||
mismatched_lq = torch.roll(state['lq_fullsize_ref'], shifts=1, dims=0)
|
|
||||||
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
|
||||||
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
|
||||||
elif self.opt['gan_type'] == 'crossgan_lrref':
|
|
||||||
d_real = net(state[self.opt['real']], state['lq'])
|
|
||||||
d_fake = net(state[self.opt['fake']].detach(), state['lq'])
|
|
||||||
mismatched_lq = torch.roll(state['lq'], shifts=1, dims=0)
|
|
||||||
d_mismatch_real = net(state[self.opt['real']], mismatched_lq)
|
|
||||||
d_mismatch_fake = net(state[self.opt['fake']].detach(), mismatched_lq)
|
|
||||||
else:
|
|
||||||
d_real = net(state[self.opt['real']])
|
|
||||||
d_fake = net(state[self.opt['fake']].detach())
|
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||||
|
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'crossgan', 'crossgan_lrref']:
|
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||||
l_real = self.criterion(d_real, True)
|
l_real = self.criterion(d_real, True)
|
||||||
l_fake = self.criterion(d_fake, False)
|
l_fake = self.criterion(d_fake, False)
|
||||||
l_total = l_real + l_fake
|
l_total = l_real + l_fake
|
||||||
if 'crossgan' in self.opt['gan_type']:
|
|
||||||
l_mreal = self.criterion(d_mismatch_real, False)
|
|
||||||
l_mfake = self.criterion(d_mismatch_fake, False)
|
|
||||||
l_total += l_mreal + l_mfake
|
|
||||||
self.metrics.append(("l_mismatch", l_mfake + l_mreal))
|
|
||||||
return l_total
|
return l_total
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||||
|
|
Loading…
Reference in New Issue
Block a user