Several things

- Fixes to 'after' and 'before' defs for steps (turns out they werent working)
- Feature nets take in a list of layers to extract. Not fully implemented yet.
- Fixes bugs with RAGAN
- Allows real input into generator gan to not be detached by param
This commit is contained in:
James Betker 2020-09-23 11:56:36 -06:00
parent 4ab989e015
commit 05963157c1
4 changed files with 39 additions and 20 deletions

View File

@ -178,10 +178,10 @@ class ExtensibleTrainer(BaseModel):
state = self.dstate state = self.dstate
for step_num, s in enumerate(self.steps): for step_num, s in enumerate(self.steps):
# Skip steps if mod_step doesn't line up. # Skip steps if mod_step doesn't line up.
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0: if 'mod_step' in s.step_opt.keys() and step % s.step_opt['mod_step'] != 0:
continue continue
# Steps can opt out of early (or late) training, make sure that happens here. # Steps can opt out of early (or late) training, make sure that happens here.
if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']: if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
continue continue
# Only set requires_grad=True for the network being trained. # Only set requires_grad=True for the network being trained.

View File

@ -6,7 +6,7 @@ import torch.nn.functional as F
# Utilizes pretrained torchvision modules for feature extraction # Utilizes pretrained torchvision modules for feature extraction
class VGGFeatureExtractor(nn.Module): class VGGFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True,
device=torch.device('cpu')): device=torch.device('cpu')):
super(VGGFeatureExtractor, self).__init__() super(VGGFeatureExtractor, self).__init__()
self.use_input_norm = use_input_norm self.use_input_norm = use_input_norm
@ -21,7 +21,8 @@ class VGGFeatureExtractor(nn.Module):
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
self.register_buffer('mean', mean) self.register_buffer('mean', mean)
self.register_buffer('std', std) self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) self.feature_layers = feature_layers
self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
# No need to BP to variable # No need to BP to variable
for k, v in self.features.named_parameters(): for k, v in self.features.named_parameters():
v.requires_grad = False v.requires_grad = False

View File

@ -159,18 +159,19 @@ def define_fixed_D(opt):
# Define network used for perceptual loss # Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None): def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
if which_model == 'vgg': if which_model == 'vgg':
# PyTorch pretrained VGG19-54, before ReLU. # PyTorch pretrained VGG19-54, before ReLU.
if use_bn: if feature_layers is None:
feature_layer = 49 if use_bn:
else: feature_layers = [49]
feature_layer = 34 else:
feature_layers = [34]
if for_training: if for_training:
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
use_input_norm=True) use_input_norm=True)
else: else:
netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
use_input_norm=True) use_input_norm=True)
elif which_model == 'wide_resnet': elif which_model == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True) netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)

View File

@ -113,20 +113,35 @@ class GeneratorGanLoss(ConfigurableLoss):
super(GeneratorGanLoss, self).__init__(opt, env) super(GeneratorGanLoss, self).__init__(opt, env)
self.opt = opt self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
self.noise = None if 'noise' not in opt.keys() else opt['noise']
self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
def forward(self, _, state): def forward(self, _, state):
netD = self.env['discriminators'][self.opt['discriminator']] netD = self.env['discriminators'][self.opt['discriminator']]
real = extract_params_from_state(self.opt['real'], state)
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
if self.noise:
nreal = []
nfake = []
for i, t in enumerate(real):
if isinstance(t, torch.Tensor):
nreal.append(t + torch.randn_like(t) * self.noise)
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
else:
nreal.append(t)
nfake.append(fake[i])
real = nreal
fake = nfake
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']: if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
pred_g_fake = netD(*fake) pred_g_fake = netD(*fake)
return self.criterion(pred_g_fake, True) return self.criterion(pred_g_fake, True)
elif self.opt['gan_type'] == 'ragan': elif self.opt['gan_type'] == 'ragan':
real = extract_params_from_state(self.opt['real'], state) pred_d_real = netD(*real)
real = [r.detach() for r in real] if self.detach_real:
pred_d_real = netD(*real).detach() pred_d_real = pred_d_real.detach()
pred_g_fake = netD(*fake) pred_g_fake = netD(*fake)
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
else: else:
raise NotImplementedError raise NotImplementedError
@ -142,6 +157,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
def forward(self, net, state): def forward(self, net, state):
self.metrics = [] self.metrics = []
real = extract_params_from_state(self.opt['real'], state) real = extract_params_from_state(self.opt['real'], state)
real = [r.detach() for r in real]
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
fake = [f.detach() for f in fake] fake = [f.detach() for f in fake]
if self.noise: if self.noise:
@ -159,17 +175,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
d_real = net(*real) d_real = net(*real)
d_fake = net(*fake) d_fake = net(*fake)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.opt['gan_type'] in ['gan', 'pixgan']: if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
l_real = self.criterion(d_real, True) l_real = self.criterion(d_real, True)
l_fake = self.criterion(d_fake, False) l_fake = self.criterion(d_fake, False)
l_total = l_real + l_fake l_total = l_real + l_fake
return l_total return l_total
elif self.opt['gan_type'] == 'ragan': elif self.opt['gan_type'] == 'ragan':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
return (self.criterion(d_real - torch.mean(d_fake), True) + return (self.criterion(d_real - torch.mean(d_fake), True) +
self.criterion(d_fake - torch.mean(d_real), False)) self.criterion(d_fake_diff, False))
else: else:
raise NotImplementedError raise NotImplementedError