forked from mrq/DL-Art-School
Several things
- Fixes to 'after' and 'before' defs for steps (turns out they werent working) - Feature nets take in a list of layers to extract. Not fully implemented yet. - Fixes bugs with RAGAN - Allows real input into generator gan to not be detached by param
This commit is contained in:
parent
4ab989e015
commit
05963157c1
|
@ -178,10 +178,10 @@ class ExtensibleTrainer(BaseModel):
|
|||
state = self.dstate
|
||||
for step_num, s in enumerate(self.steps):
|
||||
# Skip steps if mod_step doesn't line up.
|
||||
if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
|
||||
if 'mod_step' in s.step_opt.keys() and step % s.step_opt['mod_step'] != 0:
|
||||
continue
|
||||
# Steps can opt out of early (or late) training, make sure that happens here.
|
||||
if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']:
|
||||
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
|
||||
continue
|
||||
|
||||
# Only set requires_grad=True for the network being trained.
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
|||
# Utilizes pretrained torchvision modules for feature extraction
|
||||
|
||||
class VGGFeatureExtractor(nn.Module):
|
||||
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
|
||||
def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True,
|
||||
device=torch.device('cpu')):
|
||||
super(VGGFeatureExtractor, self).__init__()
|
||||
self.use_input_norm = use_input_norm
|
||||
|
@ -21,7 +21,8 @@ class VGGFeatureExtractor(nn.Module):
|
|||
# [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
|
||||
self.register_buffer('mean', mean)
|
||||
self.register_buffer('std', std)
|
||||
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
|
||||
self.feature_layers = feature_layers
|
||||
self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
|
||||
# No need to BP to variable
|
||||
for k, v in self.features.named_parameters():
|
||||
v.requires_grad = False
|
||||
|
|
|
@ -159,18 +159,19 @@ def define_fixed_D(opt):
|
|||
|
||||
|
||||
# Define network used for perceptual loss
|
||||
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None):
|
||||
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
|
||||
if which_model == 'vgg':
|
||||
# PyTorch pretrained VGG19-54, before ReLU.
|
||||
if feature_layers is None:
|
||||
if use_bn:
|
||||
feature_layer = 49
|
||||
feature_layers = [49]
|
||||
else:
|
||||
feature_layer = 34
|
||||
feature_layers = [34]
|
||||
if for_training:
|
||||
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
|
||||
netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
|
||||
use_input_norm=True)
|
||||
else:
|
||||
netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
|
||||
netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
|
||||
use_input_norm=True)
|
||||
elif which_model == 'wide_resnet':
|
||||
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
|
||||
|
|
|
@ -113,20 +113,35 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
super(GeneratorGanLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||
self.noise = None if 'noise' not in opt.keys() else opt['noise']
|
||||
self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
|
||||
|
||||
def forward(self, _, state):
|
||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
if self.noise:
|
||||
nreal = []
|
||||
nfake = []
|
||||
for i, t in enumerate(real):
|
||||
if isinstance(t, torch.Tensor):
|
||||
nreal.append(t + torch.randn_like(t) * self.noise)
|
||||
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
||||
else:
|
||||
nreal.append(t)
|
||||
nfake.append(fake[i])
|
||||
real = nreal
|
||||
fake = nfake
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
|
||||
pred_g_fake = netD(*fake)
|
||||
return self.criterion(pred_g_fake, True)
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
real = [r.detach() for r in real]
|
||||
pred_d_real = netD(*real).detach()
|
||||
pred_d_real = netD(*real)
|
||||
if self.detach_real:
|
||||
pred_d_real = pred_d_real.detach()
|
||||
pred_g_fake = netD(*fake)
|
||||
return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
|
||||
self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -142,6 +157,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
real = [r.detach() for r in real]
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
fake = [f.detach() for f in fake]
|
||||
if self.noise:
|
||||
|
@ -159,17 +175,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
d_real = net(*real)
|
||||
d_fake = net(*fake)
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||
l_real = self.criterion(d_real, True)
|
||||
l_fake = self.criterion(d_fake, False)
|
||||
l_total = l_real + l_fake
|
||||
return l_total
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
d_fake_diff = d_fake - torch.mean(d_real)
|
||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||
return (self.criterion(d_real - torch.mean(d_fake), True) +
|
||||
self.criterion(d_fake - torch.mean(d_real), False))
|
||||
self.criterion(d_fake_diff, False))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user