Add u-net discriminator with feature output

James Betker 2020-07-16 10:10:09 -06:00
parent 0c4c388e15
commit 8d061a2687
6 changed files with 100 additions and 7 deletions

View File

@@ -271,7 +271,7 @@ class SRGANModel(BaseModel):
             # it should target this value.
             if self.l_gan_w > 0:
-                if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
+                if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                     pred_g_fake = self.netD(fake_GenOut)
                     l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                 elif self.opt['train']['gan_type'] == 'ragan':
@@ -324,6 +324,17 @@ class SRGANModel(BaseModel):
             # Apply noise to the inputs to slow discriminator convergence.
             var_ref = var_ref + noise
             fake_H = fake_H + noise
+            l_d_fea_real = torch.zeros(1)
+            l_d_fea_fake = torch.zeros(1)
+            if self.opt['train']['gan_type'] == 'pixgan_fea':
+                # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
+                disc_fea_scale = .5
+                _, fea_real = self.netD(var_ref, output_feature_vector=True)
+                actual_fea = self.netF(var_ref)
+                l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor
+                _, fea_fake = self.netD(fake_H, output_feature_vector=True)
+                actual_fea = self.netF(fake_H)
+                l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor
             if self.opt['train']['gan_type'] == 'gan':
                 # need to forward and backward separately, since batch norm statistics differ
                 # real
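Note: the new pixgan_fea branch gives the discriminator an auxiliary feature-matching target. A minimal sketch of that objective, assuming (as in typical SRGAN setups) that self.netF is a frozen VGG-style feature extractor and self.cri_fea an L1 loss; the helper name disc_feature_loss is hypothetical:

import torch
import torch.nn as nn

cri_fea = nn.L1Loss()

def disc_feature_loss(netD, netF, img, disc_fea_scale=0.5, mega_batch_factor=1):
    # The discriminator returns (per-pixel logits, bottleneck feature map)
    # when asked for its feature vector.
    _, disc_fea = netD(img, output_feature_vector=True)
    # netF is assumed frozen; its output acts as a fixed perceptual target.
    with torch.no_grad():
        target_fea = netF(img)
    return cri_fea(disc_fea, target_fea) * disc_fea_scale / mega_batch_factor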
@@ -338,7 +349,7 @@ class SRGANModel(BaseModel):
                 l_d_fake_log = l_d_fake * self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
-            if self.opt['train']['gan_type'] == 'pixgan':
+            if 'pixgan' in self.opt['train']['gan_type']:
                 # randomly determine portions of the image to swap to keep the discriminator honest.
                 pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
                 disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
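For reference, both U-net discriminators report pixgan_parameters() of (3, 4): three label channels at one quarter of the input resolution. A quick shape check, with an assumed batch of 16 references at 128x128:

# Hypothetical shapes: pixgan_parameters() -> (3, 4) means 3 label channels
# at 1/4 resolution, so 16 references at 128x128 give a (16, 3, 32, 32) grid.
var_ref_shape = (16, 3, 128, 128)
pixdisc_channels, pixdisc_output_reduction = 3, 4
disc_output_shape = (var_ref_shape[0], pixdisc_channels,
                     var_ref_shape[2] // pixdisc_output_reduction,
                     var_ref_shape[3] // pixdisc_output_reduction)
assert disc_output_shape == (16, 3, 32, 32)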
@@ -379,12 +390,14 @@ class SRGANModel(BaseModel):
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
                 l_d_real_log = l_d_real * self.mega_batch_factor
+                l_d_real += l_d_fea_real
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 # fake
                 pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
                 l_d_fake_log = l_d_fake * self.mega_batch_factor
+                l_d_fake += l_d_fea_fake
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
@@ -470,6 +483,10 @@ class SRGANModel(BaseModel):
             if self.l_gan_w > 0 and step > self.G_warmup:
                 self.add_log_entry('l_d_real', l_d_real_log.item())
                 self.add_log_entry('l_d_fake', l_d_fake_log.item())
+                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
                 self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
                 self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))

View File

@@ -237,3 +237,72 @@ class Discriminator_UNet(nn.Module):
     def pixgan_parameters(self):
         return 3, 4
+
+
+class Discriminator_UNet_FeaOut(nn.Module):
+    def __init__(self, in_nc, nf):
+        super(Discriminator_UNet_FeaOut, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
+
+        self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+
+        self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+    def forward(self, x, output_feature_vector=False):
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        feat = self.conv4_0(fea3)
+        fea4 = self.conv4_1(feat)
+
+        # And the pyramid network!
+        u1 = self.up1(fea4, fea3)
+        loss1 = self.collapse1(self.proc1(u1))
+        u2 = self.up2(u1, fea2)
+        loss2 = self.collapse2(self.proc2(u2))
+        u3 = self.up3(u2, fea1)
+        loss3 = self.collapse3(self.proc3(u3))
+        res = loss3.shape[2:]
+
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
+                                     F.interpolate(loss2, scale_factor=2),
+                                     F.interpolate(loss3, scale_factor=1)], dim=1)
+        if output_feature_vector:
+            return combined_losses.view(-1, 1), feat
+        else:
+            return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 4
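A short usage sketch of the new class (input size assumed 128x128 to match the shape comments; batch size 4 is arbitrary). The three collapse heads are interpolated to a common 32x32 grid, stacked as 3 channels, and flattened so every cell becomes an independent logit:

import torch

D = Discriminator_UNet_FeaOut(in_nc=3, nf=64)
x = torch.randn(4, 3, 128, 128)
logits = D(x)                                    # (4 * 3 * 32 * 32, 1)
logits, feat = D(x, output_feature_vector=True)  # feat: (4, 512, 8, 8)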

View File

@@ -23,7 +23,7 @@ class GANLoss(nn.Module):
         self.real_label_val = real_label_val
         self.fake_label_val = fake_label_val

-        if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan':
+        if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan' or self.gan_type == "pixgan_fea":
             self.loss = nn.BCEWithLogitsLoss()
         elif self.gan_type == 'lsgan':
             self.loss = nn.MSELoss()
@@ -46,7 +46,7 @@ class GANLoss(nn.Module):
             return torch.empty_like(input).fill_(self.fake_label_val)

     def forward(self, input, target_is_real):
-        if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool):
+        if 'pixgan' in self.gan_type and not isinstance(target_is_real, bool):
             target_label = target_is_real
         else:
             target_label = self.get_target_label(input, target_is_real)
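With this change, any pixgan-family loss accepts either a plain bool or a prebuilt per-pixel label tensor flattened to match the discriminator output. A hedged sketch (shapes assumed from the 128x128 example above):

import torch

criterion = GANLoss('pixgan_fea', real_label_val=1.0, fake_label_val=0.0)
pred = torch.randn(16 * 3 * 32 * 32, 1)  # flattened discriminator logits
target = torch.ones(16, 3, 32, 32)       # label everything "real"...
target[:, :, :16, :16] = 0.0             # ...except a swapped patch
loss = criterion(pred, target.view(-1, 1))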

View File

@@ -122,6 +122,8 @@ def define_D(opt):
         netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     elif which_model == "discriminator_unet":
         netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+    elif which_model == "discriminator_unet_fea":
+        netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
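To select the new discriminator from a training config, the network_D options would look something like the following (expressed as the dict define_D receives; the key names mirror the surrounding branches and are assumed, not shown in this diff):

# Hypothetical options fragment; only the keys read above are included.
opt = {'network_D': {'which_model_D': 'discriminator_unet_fea',
                     'in_nc': 3,
                     'nf': 64}}
netD = define_D(opt)  # -> Discriminator_UNet_FeaOut(in_nc=3, nf=64)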

View File

@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)

View File

@@ -3,6 +3,7 @@ from torch import nn
 import models.archs.SRG1_arch as srg1
 import models.archs.SwitchedResidualGenerator_arch as srg
 import models.archs.NestedSwitchGenerator as nsg
+import models.archs.discriminator_vgg_arch as disc
 import functools

 blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
@@ -93,6 +94,7 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
+    '''
     test_stability(functools.partial(srg.DualOutputSRG,
                                      switch_depth=4,
                                      switch_filters=64,
@@ -105,7 +107,7 @@ if __name__ == "__main__":
                                      upsample_factor=4),
                    torch.randn(1, 3, 32, 32),
                    device='cpu')
-
+    '''
     '''
     test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
                                      switch_filters=[32,32,32,32],
@@ -126,3 +128,6 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
+    test_stability(functools.partial(disc.Discriminator_UNet_FeaOut, 3, 64),
+                   torch.randn(1,3,128,128),
+                   device='cpu')