Add U-Net discriminator with feature output
parent 0c4c388e15
commit 8d061a2687
@@ -271,7 +271,7 @@ class SRGANModel(BaseModel):
                 # it should target this value.

                 if self.l_gan_w > 0:
-                    if self.opt['train']['gan_type'] == 'gan' or self.opt['train']['gan_type'] == 'pixgan':
+                    if self.opt['train']['gan_type'] == 'gan' or 'pixgan' in self.opt['train']['gan_type']:
                         pred_g_fake = self.netD(fake_GenOut)
                         l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
                     elif self.opt['train']['gan_type'] == 'ragan':
@@ -324,6 +324,17 @@ class SRGANModel(BaseModel):
             # Apply noise to the inputs to slow discriminator convergence.
             var_ref = var_ref + noise
             fake_H = fake_H + noise
+            l_d_fea_real = torch.zeros(1)
+            l_d_fea_fake = torch.zeros(1)
+            if self.opt['train']['gan_type'] == 'pixgan_fea':
+                # Compute a feature loss which is added to the GAN loss computed later to guide the discriminator better.
+                disc_fea_scale = .5
+                _, fea_real = self.netD(var_ref, output_feature_vector=True)
+                actual_fea = self.netF(var_ref)
+                l_d_fea_real = self.cri_fea(fea_real, actual_fea) * disc_fea_scale / self.mega_batch_factor
+                _, fea_fake = self.netD(fake_H, output_feature_vector=True)
+                actual_fea = self.netF(fake_H)
+                l_d_fea_fake = self.cri_fea(fea_fake, actual_fea) * disc_fea_scale / self.mega_batch_factor
             if self.opt['train']['gan_type'] == 'gan':
                 # need to forward and backward separately, since batch norm statistics differ
                 # real
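The new 'pixgan_fea' branch above is essentially a feature-matching regularizer on the discriminator: its bottleneck features are pulled toward those of the perceptual network netF for both the real and the fake batch. A minimal sketch of that term in isolation follows; it is not the repo's literal code, and it assumes netF is the fixed perceptual feature network this model already builds and cri_fea is its L1/MSE feature criterion.

    # Hedged sketch of the auxiliary discriminator feature loss introduced above.
    # Assumptions: netD is Discriminator_UNet_FeaOut (added further down in this commit),
    # netF is a frozen perceptual feature extractor, cri_fea is an L1/MSE criterion, and
    # disc_fea_scale / mega_batch_factor mirror the values used above.
    import torch

    def disc_feature_loss(netD, netF, cri_fea, img, disc_fea_scale=0.5, mega_batch_factor=1):
        # Second return value is the discriminator's conv4_0 activation map.
        _, disc_fea = netD(img, output_feature_vector=True)
        with torch.no_grad():  # the perceptual network stays fixed
            target_fea = netF(img)
        return cri_fea(disc_fea, target_fea) * disc_fea_scale / mega_batch_factor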
@@ -338,7 +349,7 @@ class SRGANModel(BaseModel):
                 l_d_fake_log = l_d_fake * self.mega_batch_factor
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()
-            if self.opt['train']['gan_type'] == 'pixgan':
+            if 'pixgan' in self.opt['train']['gan_type']:
                 # randomly determine portions of the image to swap to keep the discriminator honest.
                 pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
                 disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
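The label-swapping referenced by this comment predates the commit (only the condition was broadened), but for orientation the idea is roughly the following. This is an illustrative sketch of what the comment describes, not the repo's implementation, and the matching swap of image patches between var_ref and fake_H is omitted.

    # Illustrative only: build a per-location target map at the discriminator's output
    # resolution and flip random rectangles to the fake label, so the discriminator is
    # graded per region rather than on a single global decision. The corresponding
    # image-patch swap between the real and fake batches is not shown here.
    import random
    import torch

    def make_swapped_targets(disc_output_shape, n_swaps=3, real_val=1.0, fake_val=0.0):
        b, c, h, w = disc_output_shape
        target = torch.full(disc_output_shape, real_val)
        for _ in range(n_swaps):
            sh, sw = random.randint(1, h // 2), random.randint(1, w // 2)
            top, left = random.randint(0, h - sh), random.randint(0, w - sw)
            target[:, :, top:top + sh, left:left + sw] = fake_val
        return target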
@@ -379,12 +390,14 @@ class SRGANModel(BaseModel):
                 pred_d_real = self.netD(var_ref)
                 l_d_real = self.cri_gan(pred_d_real, real) / self.mega_batch_factor
                 l_d_real_log = l_d_real * self.mega_batch_factor
+                l_d_real += l_d_fea_real
                 with amp.scale_loss(l_d_real, self.optimizer_D, loss_id=2) as l_d_real_scaled:
                     l_d_real_scaled.backward()
                 # fake
                 pred_d_fake = self.netD(fake_H)
                 l_d_fake = self.cri_gan(pred_d_fake, fake) / self.mega_batch_factor
                 l_d_fake_log = l_d_fake * self.mega_batch_factor
+                l_d_fake += l_d_fea_fake
                 with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                     l_d_fake_scaled.backward()

@@ -470,6 +483,10 @@ class SRGANModel(BaseModel):
             if self.l_gan_w > 0 and step > self.G_warmup:
                 self.add_log_entry('l_d_real', l_d_real_log.item())
                 self.add_log_entry('l_d_fake', l_d_fake_log.item())
+                self.add_log_entry('l_d_fea_fake', l_d_fea_fake.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fea_real', l_d_fea_real.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_fake_total', l_d_fake.item() * self.mega_batch_factor)
+                self.add_log_entry('l_d_real_total', l_d_real.item() * self.mega_batch_factor)
                 self.add_log_entry('D_fake', torch.mean(pred_d_fake.detach()))
                 self.add_log_entry('D_diff', torch.mean(pred_d_fake) - torch.mean(pred_d_real))

@@ -237,3 +237,72 @@ class Discriminator_UNet(nn.Module):
     def pixgan_parameters(self):
         return 3, 4

+
+class Discriminator_UNet_FeaOut(nn.Module):
+    def __init__(self, in_nc, nf):
+        super(Discriminator_UNet_FeaOut, self).__init__()
+        # [64, 128, 128]
+        self.conv0_0 = ConvGnLelu(in_nc, nf, kernel_size=3, bias=True, activation=False)
+        self.conv0_1 = ConvGnLelu(nf, nf, kernel_size=3, stride=2, bias=False)
+        # [64, 64, 64]
+        self.conv1_0 = ConvGnLelu(nf, nf * 2, kernel_size=3, bias=False)
+        self.conv1_1 = ConvGnLelu(nf * 2, nf * 2, kernel_size=3, stride=2, bias=False)
+        # [128, 32, 32]
+        self.conv2_0 = ConvGnLelu(nf * 2, nf * 4, kernel_size=3, bias=False)
+        self.conv2_1 = ConvGnLelu(nf * 4, nf * 4, kernel_size=3, stride=2, bias=False)
+        # [256, 16, 16]
+        self.conv3_0 = ConvGnLelu(nf * 4, nf * 8, kernel_size=3, bias=False)
+        self.conv3_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+        # [512, 8, 8]
+        self.conv4_0 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, bias=False)
+        self.conv4_1 = ConvGnLelu(nf * 8, nf * 8, kernel_size=3, stride=2, bias=False)
+
+        self.up1 = ExpansionBlock(nf * 8, nf * 8, block=ConvGnLelu)
+        self.proc1 = ConvGnLelu(nf * 8, nf * 8, bias=False)
+        self.collapse1 = ConvGnLelu(nf * 8, 1, bias=True, norm=False, activation=False)
+
+        self.up2 = ExpansionBlock(nf * 8, nf * 4, block=ConvGnLelu)
+        self.proc2 = ConvGnLelu(nf * 4, nf * 4, bias=False)
+        self.collapse2 = ConvGnLelu(nf * 4, 1, bias=True, norm=False, activation=False)
+
+        self.up3 = ExpansionBlock(nf * 4, nf * 2, block=ConvGnLelu)
+        self.proc3 = ConvGnLelu(nf * 2, nf * 2, bias=False)
+        self.collapse3 = ConvGnLelu(nf * 2, 1, bias=True, norm=False, activation=False)
+
+    def forward(self, x, output_feature_vector=False):
+        fea0 = self.conv0_0(x)
+        fea0 = self.conv0_1(fea0)
+
+        fea1 = self.conv1_0(fea0)
+        fea1 = self.conv1_1(fea1)
+
+        fea2 = self.conv2_0(fea1)
+        fea2 = self.conv2_1(fea2)
+
+        fea3 = self.conv3_0(fea2)
+        fea3 = self.conv3_1(fea3)
+
+        feat = self.conv4_0(fea3)
+        fea4 = self.conv4_1(feat)
+
+        # And the pyramid network!
+        u1 = self.up1(fea4, fea3)
+        loss1 = self.collapse1(self.proc1(u1))
+        u2 = self.up2(u1, fea2)
+        loss2 = self.collapse2(self.proc2(u2))
+        u3 = self.up3(u2, fea1)
+        loss3 = self.collapse3(self.proc3(u3))
+        res = loss3.shape[2:]
+
+        # Compress all of the loss values into the batch dimension. The actual loss attached to this output will
+        # then know how to handle them.
+        combined_losses = torch.cat([F.interpolate(loss1, scale_factor=4),
+                                     F.interpolate(loss2, scale_factor=2),
+                                     F.interpolate(loss3, scale_factor=1)], dim=1)
+        if output_feature_vector:
+            return combined_losses.view(-1, 1), feat
+        else:
+            return combined_losses.view(-1, 1)
+
+    def pixgan_parameters(self):
+        return 3, 4
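A quick smoke test of the new class, in the spirit of the test-harness change at the end of this commit; the printed shapes are inferred from the strides above and assume the ConvGnLelu blocks use size-preserving padding.

    # Hedged usage sketch for Discriminator_UNet_FeaOut on a 128x128 input.
    import torch
    import models.archs.discriminator_vgg_arch as disc

    netD = disc.Discriminator_UNet_FeaOut(in_nc=3, nf=64)
    x = torch.randn(1, 3, 128, 128)

    # Plain call: one logit per output location, flattened into the batch dimension.
    logits = netD(x)
    print(logits.shape)  # expected (3072, 1): 3 channels x (128/4)^2 locations, per pixgan_parameters()

    # Feature-output call used by the new 'pixgan_fea' path in SRGANModel.
    logits, feat = netD(x, output_feature_vector=True)
    print(feat.shape)    # expected (1, 512, 8, 8): the conv4_0 activations (nf * 8 channels)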
@@ -23,7 +23,7 @@ class GANLoss(nn.Module):
         self.real_label_val = real_label_val
         self.fake_label_val = fake_label_val

-        if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan':
+        if self.gan_type == 'gan' or self.gan_type == 'ragan' or self.gan_type == 'pixgan' or self.gan_type == "pixgan_fea":
             self.loss = nn.BCEWithLogitsLoss()
         elif self.gan_type == 'lsgan':
             self.loss = nn.MSELoss()
@@ -46,7 +46,7 @@ class GANLoss(nn.Module):
             return torch.empty_like(input).fill_(self.fake_label_val)

     def forward(self, input, target_is_real):
-        if self.gan_type == 'pixgan' and not isinstance(target_is_real, bool):
+        if 'pixgan' in self.gan_type and not isinstance(target_is_real, bool):
             target_label = target_is_real
         else:
             target_label = self.get_target_label(input, target_is_real)
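With the broadened check above, every 'pixgan*' loss accepts either a plain bool or a prebuilt per-location label tensor. A hedged sketch of both call patterns; the import path for GANLoss is assumed to follow the upstream MMSR layout (models.loss).

    # Sketch of the two ways cri_gan is driven for the pixgan family; import path assumed.
    import torch
    from models.loss import GANLoss

    cri_gan = GANLoss('pixgan_fea', real_label_val=1.0, fake_label_val=0.0)  # BCEWithLogitsLoss inside

    logits = torch.randn(3072, 1)   # flattened discriminator output, as produced above

    # Generator side: a bool target is expanded to a uniform label map internally.
    l_g = cri_gan(logits, True)

    # Discriminator side: a prebuilt label map (e.g. with swapped regions) is used directly.
    target = torch.ones_like(logits)
    target[:1024] = 0.0             # mark a third of the locations as fake
    l_d = cri_gan(logits, target)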
@@ -122,6 +122,8 @@ def define_D(opt):
         netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     elif which_model == "discriminator_unet":
         netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
+    elif which_model == "discriminator_unet_fea":
+        netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
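Wiring this up from a training config is then a single option change. A sketch of the option sub-dicts that define_D and SRGANModel read, expressed as a Python dict; the exact key names (network_D, which_model_D) are assumed from the upstream MMSR layout and may not match this repo's YAML one-for-one.

    # Hedged sketch of the options that exercise the new code paths; key names assumed.
    opt = {
        'network_D': {
            'which_model_D': 'discriminator_unet_fea',  # selects Discriminator_UNet_FeaOut above
            'in_nc': 3,   # RGB input
            'nf': 64,     # base filter count
        },
        'train': {
            'gan_type': 'pixgan_fea',  # enables the discriminator feature loss in SRGANModel
        },
    }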
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_dual_srg.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg2_fdisc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -3,6 +3,7 @@ from torch import nn
 import models.archs.SRG1_arch as srg1
 import models.archs.SwitchedResidualGenerator_arch as srg
 import models.archs.NestedSwitchGenerator as nsg
+import models.archs.discriminator_vgg_arch as disc
 import functools

 blacklisted_modules = [nn.Conv2d, nn.ReLU, nn.LeakyReLU, nn.BatchNorm2d, nn.Softmax]
@@ -93,6 +94,7 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
+    '''
     test_stability(functools.partial(srg.DualOutputSRG,
                                      switch_depth=4,
                                      switch_filters=64,
@@ -105,7 +107,7 @@ if __name__ == "__main__":
                                      upsample_factor=4),
                    torch.randn(1, 3, 32, 32),
                    device='cpu')
-
+    '''
     '''
     test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
                                      switch_filters=[32,32,32,32],
@@ -126,3 +128,6 @@ if __name__ == "__main__":
                    torch.randn(1, 3, 64, 64),
                    device='cuda')
     '''
+    test_stability(functools.partial(disc.Discriminator_UNet_FeaOut, 3, 64),
+                   torch.randn(1,3,128,128),
+                   device='cpu')