Several things

- Fixes to 'after' and 'before' defs for steps (turns out they werent working)
- Feature nets take in a list of layers to extract. Not fully implemented yet.
- Fixes bugs with RAGAN
- Allows real input into generator gan to not be detached by param
This commit is contained in:
James Betker 2020-09-23 11:56:36 -06:00
parent 4ab989e015
commit 05963157c1
4 changed files with 39 additions and 20 deletions

View File

@ -178,10 +178,10 @@ class ExtensibleTrainer(BaseModel):
state = self.dstate
for step_num, s in enumerate(self.steps):
# Skip steps if mod_step doesn't line up.
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
if 'mod_step' in s.step_opt.keys() and step % s.step_opt['mod_step'] != 0:
continue
# Steps can opt out of early (or late) training, make sure that happens here.
if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']:
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
continue
# Only set requires_grad=True for the network being trained.

View File

@ -6,7 +6,7 @@ import torch.nn.functional as F
# Utilizes pretrained torchvision modules for feature extraction
class VGGFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True,
device=torch.device('cpu')):
super(VGGFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm
@ -21,7 +21,8 @@ class VGGFeatureExtractor(nn.Module):
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
self.feature_layers = feature_layers
self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False

View File

@ -159,18 +159,19 @@ def define_fixed_D(opt):
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None):
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
if which_model == 'vgg':
# PyTorch pretrained VGG19-54, before ReLU.
if feature_layers is None:
if use_bn:
feature_layer = 49
feature_layers = [49]
else:
feature_layer = 34
feature_layers = [34]
if for_training:
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
use_input_norm=True)
else:
netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
use_input_norm=True)
elif which_model == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)

View File

@ -113,20 +113,35 @@ class GeneratorGanLoss(ConfigurableLoss):
super(GeneratorGanLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
self.noise = None if 'noise' not in opt.keys() else opt['noise']
self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
def forward(self, _, state):
netD = self.env['discriminators'][self.opt['discriminator']]
real = extract_params_from_state(self.opt['real'], state)
fake = extract_params_from_state(self.opt['fake'], state)
if self.noise:
nreal = []
nfake = []
for i, t in enumerate(real):
if isinstance(t, torch.Tensor):
nreal.append(t + torch.randn_like(t) * self.noise)
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
else:
nreal.append(t)
nfake.append(fake[i])
real = nreal
fake = nfake
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
pred_g_fake = netD(*fake)
return self.criterion(pred_g_fake, True)
elif self.opt['gan_type'] == 'ragan':
real = extract_params_from_state(self.opt['real'], state)
real = [r.detach() for r in real]
pred_d_real = netD(*real).detach()
pred_d_real = netD(*real)
if self.detach_real:
pred_d_real = pred_d_real.detach()
pred_g_fake = netD(*fake)
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
else:
raise NotImplementedError
@ -142,6 +157,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
def forward(self, net, state):
self.metrics = []
real = extract_params_from_state(self.opt['real'], state)
real = [r.detach() for r in real]
fake = extract_params_from_state(self.opt['fake'], state)
fake = [f.detach() for f in fake]
if self.noise:
@ -159,17 +175,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
d_real = net(*real)
d_fake = net(*fake)
if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.opt['gan_type'] in ['gan', 'pixgan']:
l_real = self.criterion(d_real, True)
l_fake = self.criterion(d_fake, False)
l_total = l_real + l_fake
return l_total
elif self.opt['gan_type'] == 'ragan':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
return (self.criterion(d_real - torch.mean(d_fake), True) +
self.criterion(d_fake - torch.mean(d_real), False))
self.criterion(d_fake_diff, False))
else:
raise NotImplementedError