forked from mrq/DL-Art-School
Several things:
- Fixes to 'after' and 'before' defs for steps (turns out they weren't working)
- Feature nets take in a list of layers to extract. Not fully implemented yet.
- Fixes bugs with RAGAN
- Allows the real input to the generator GAN loss to optionally not be detached, controlled by a param
parent 4ab989e015
commit 05963157c1
@@ -178,10 +178,10 @@ class ExtensibleTrainer(BaseModel):
         state = self.dstate
         for step_num, s in enumerate(self.steps):
             # Skip steps if mod_step doesn't line up.
-            if 'mod_step' in s.opt.keys() and step % s.opt['mod_step'] != 0:
+            if 'mod_step' in s.step_opt.keys() and step % s.step_opt['mod_step'] != 0:
                 continue
             # Steps can opt out of early (or late) training, make sure that happens here.
-            if 'after' in s.opt.keys() and step < s.opt['after'] or 'before' in s.opt.keys() and step > s.opt['before']:
+            if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
                 continue
 
             # Only set requires_grad=True for the network being trained.
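For context, the gating this hunk fixes can be reproduced in isolation. Below is a minimal stand-alone sketch of the same checks, assuming step_opt is the plain dict of step options and that 'mod_step', 'after', and 'before' are all optional keys; it is not the trainer's actual code path.

def step_is_active(step_opt: dict, step: int) -> bool:
    """Return True if a training step should run at this global iteration.

    Mirrors the fixed checks: 'mod_step' skips iterations that don't line up,
    'after' delays a step until a given iteration, 'before' retires it afterwards.
    """
    if 'mod_step' in step_opt and step % step_opt['mod_step'] != 0:
        return False
    if 'after' in step_opt and step < step_opt['after']:
        return False
    if 'before' in step_opt and step > step_opt['before']:
        return False
    return True

# Example: a step that runs every 2 iterations, but only from iteration 1000 onward.
assert not step_is_active({'mod_step': 2, 'after': 1000}, 500)
assert step_is_active({'mod_step': 2, 'after': 1000}, 1000)
assert not step_is_active({'before': 2000}, 2001)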
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 # Utilizes pretrained torchvision modules for feature extraction
 
 class VGGFeatureExtractor(nn.Module):
-    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
+    def __init__(self, feature_layers=[34], use_bn=False, use_input_norm=True,
                  device=torch.device('cpu')):
         super(VGGFeatureExtractor, self).__init__()
         self.use_input_norm = use_input_norm
@@ -21,7 +21,8 @@ class VGGFeatureExtractor(nn.Module):
         # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
         self.register_buffer('mean', mean)
         self.register_buffer('std', std)
-        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+        self.feature_layers = feature_layers
+        self.features = nn.Sequential(*list(model.features.children())[:(max(feature_layers) + 1)])
         # No need to BP to variable
         for k, v in self.features.named_parameters():
             v.requires_grad = False
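The commit message notes that multi-layer extraction is "not fully implemented yet": this hunk only truncates the backbone at the deepest requested index. The sketch below shows one possible way a forward pass could return features at each requested layer, assuming a torchvision VGG19 backbone and ImageNet normalization; it is an illustration, not the repository's implementation.

import torch
import torch.nn as nn
import torchvision

class MultiLayerVGGFeatures(nn.Module):
    """Illustrative sketch: extract features at several VGG layer indices in a
    single pass by iterating the truncated feature stack and collecting outputs
    at the requested positions."""
    def __init__(self, feature_layers=(22, 34), use_input_norm=True):
        super().__init__()
        model = torchvision.models.vgg19(pretrained=True)
        self.feature_layers = sorted(feature_layers)
        self.blocks = nn.Sequential(*list(model.features.children())[:max(self.feature_layers) + 1])
        self.use_input_norm = use_input_norm
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        # No need to backprop into the pretrained feature net.
        for p in self.blocks.parameters():
            p.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        outputs = []
        for idx, layer in enumerate(self.blocks):
            x = layer(x)
            if idx in self.feature_layers:
                outputs.append(x)
        return outputs  # one feature tensor per requested layer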
@@ -159,18 +159,19 @@ def define_fixed_D(opt):
 
 
 # Define network used for perceptual loss
-def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None):
+def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
     if which_model == 'vgg':
         # PyTorch pretrained VGG19-54, before ReLU.
-        if use_bn:
-            feature_layer = 49
-        else:
-            feature_layer = 34
+        if feature_layers is None:
+            if use_bn:
+                feature_layers = [49]
+            else:
+                feature_layers = [34]
         if for_training:
-            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
                                                              use_input_norm=True)
         else:
-            netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+            netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
                                                     use_input_norm=True)
     elif which_model == 'wide_resnet':
         netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
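Callers can now pass the new keyword through define_F; when it is omitted, the old single-layer default (layer 34, or 49 with batch norm) is preserved. A hedged usage sketch follows; the import path is an assumption, only the signature comes from this diff.

# Hypothetical usage; module path assumed, signature taken from the diff above.
from models.networks import define_F

netF_default = define_F(which_model='vgg')                         # falls back to [34] (or [49] with use_bn)
netF_multi = define_F(which_model='vgg', feature_layers=[22, 34])  # perceptual loss over two depths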
@@ -113,20 +113,35 @@ class GeneratorGanLoss(ConfigurableLoss):
         super(GeneratorGanLoss, self).__init__(opt, env)
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
+        self.noise = None if 'noise' not in opt.keys() else opt['noise']
+        self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
 
     def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
+        real = extract_params_from_state(self.opt['real'], state)
         fake = extract_params_from_state(self.opt['fake'], state)
+        if self.noise:
+            nreal = []
+            nfake = []
+            for i, t in enumerate(real):
+                if isinstance(t, torch.Tensor):
+                    nreal.append(t + torch.randn_like(t) * self.noise)
+                    nfake.append(fake[i] + torch.randn_like(t) * self.noise)
+                else:
+                    nreal.append(t)
+                    nfake.append(fake[i])
+            real = nreal
+            fake = nfake
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
             pred_g_fake = netD(*fake)
             return self.criterion(pred_g_fake, True)
         elif self.opt['gan_type'] == 'ragan':
-            real = extract_params_from_state(self.opt['real'], state)
-            real = [r.detach() for r in real]
-            pred_d_real = netD(*real).detach()
+            pred_d_real = netD(*real)
+            if self.detach_real:
+                pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
-            return (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
-                    self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+                    self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
             raise NotImplementedError
 
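The RAGAN branch now reuses the real tensors fetched at the top of forward(), calls self.criterion instead of the nonexistent self.cri_gan, and only detaches the real prediction when detach_real is set. As a self-contained reference for what this relativistic-average formulation computes, here is a sketch using plain BCE-with-logits in place of the repository's GANLoss wrapper.

import torch
import torch.nn.functional as F

def ragan_generator_loss(pred_d_real: torch.Tensor,
                         pred_g_fake: torch.Tensor,
                         detach_real: bool = True) -> torch.Tensor:
    """Relativistic-average GAN loss for the generator.

    The generator wants fake samples to score above the average real logit and
    real samples to score below the average fake logit. Detaching the real
    predictions (the default, as in the patched loss) keeps this term from
    backpropagating into whatever produced the real inputs.
    """
    if detach_real:
        pred_d_real = pred_d_real.detach()
    l_real = F.binary_cross_entropy_with_logits(
        pred_d_real - pred_g_fake.mean(), torch.zeros_like(pred_d_real))
    l_fake = F.binary_cross_entropy_with_logits(
        pred_g_fake - pred_d_real.mean(), torch.ones_like(pred_g_fake))
    return (l_real + l_fake) / 2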
@@ -142,6 +157,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
     def forward(self, net, state):
         self.metrics = []
         real = extract_params_from_state(self.opt['real'], state)
+        real = [r.detach() for r in real]
         fake = extract_params_from_state(self.opt['fake'], state)
         fake = [f.detach() for f in fake]
         if self.noise:
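Both GAN losses now detach their inputs where gradients should not flow and can perturb real and fake batches with Gaussian noise. A standalone sketch of that instance-noise idea follows, under the assumption that noise is a scalar standard deviation as in the options above; this is not the repository's exact helper.

import torch

def add_instance_noise(tensors, noise_std: float):
    """Add i.i.d. Gaussian noise to every tensor in a list, passing non-tensors
    through untouched. Feeding slightly noised real and fake batches to the
    discriminator is a common trick for stabilizing GAN training."""
    if not noise_std:
        return list(tensors)
    return [t + torch.randn_like(t) * noise_std if isinstance(t, torch.Tensor) else t
            for t in tensors]

# Example: perturb real and fake inputs with the same noise scale.
real = add_instance_noise([torch.randn(4, 3, 32, 32)], noise_std=1e-2)
fake = add_instance_noise([torch.randn(4, 3, 32, 32)], noise_std=1e-2)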
@@ -159,17 +175,18 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         d_real = net(*real)
         d_fake = net(*fake)
 
-        self.metrics.append(("d_fake", torch.mean(d_fake)))
-        self.metrics.append(("d_real", torch.mean(d_real)))
-
         if self.opt['gan_type'] in ['gan', 'pixgan']:
+            self.metrics.append(("d_fake", torch.mean(d_fake)))
+            self.metrics.append(("d_real", torch.mean(d_real)))
             l_real = self.criterion(d_real, True)
             l_fake = self.criterion(d_fake, False)
             l_total = l_real + l_fake
             return l_total
         elif self.opt['gan_type'] == 'ragan':
+            d_fake_diff = d_fake - torch.mean(d_real)
+            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
             return (self.criterion(d_real - torch.mean(d_fake), True) +
-                    self.criterion(d_fake - torch.mean(d_real), False))
+                    self.criterion(d_fake_diff, False))
         else:
             raise NotImplementedError
 
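For symmetry with the generator side, here is a self-contained sketch of the relativistic-average discriminator loss this hunk rewrites, including the d_fake_diff quantity that is now logged as a metric. Plain BCE-with-logits stands in for the repository's criterion.

import torch
import torch.nn.functional as F

def ragan_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor):
    """Relativistic-average GAN loss for the discriminator.

    Returns the loss together with the mean of d_fake_diff, the quantity the
    patched loss records as a metric: how far fake logits sit relative to the
    mean real logit.
    """
    d_fake_diff = d_fake - d_real.mean()
    loss = (F.binary_cross_entropy_with_logits(
                d_real - d_fake.mean(), torch.ones_like(d_real)) +
            F.binary_cross_entropy_with_logits(
                d_fake_diff, torch.zeros_like(d_fake)))
    return loss, d_fake_diff.mean()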