Interpreted feature loss to extensibletrainer
This commit is contained in:
parent
886d59d5df
commit
8b52d46847
|
@ -2,7 +2,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from models.networks import define_F
|
from models.networks import define_F
|
||||||
from models.loss import GANLoss
|
from models.loss import GANLoss
|
||||||
from torchvision.utils import save_image
|
|
||||||
|
|
||||||
|
|
||||||
def create_generator_loss(opt_loss, env):
|
def create_generator_loss(opt_loss, env):
|
||||||
|
@ -11,6 +10,8 @@ def create_generator_loss(opt_loss, env):
|
||||||
return PixLoss(opt_loss, env)
|
return PixLoss(opt_loss, env)
|
||||||
elif type == 'feature':
|
elif type == 'feature':
|
||||||
return FeatureLoss(opt_loss, env)
|
return FeatureLoss(opt_loss, env)
|
||||||
|
elif type == 'interpreted_feature':
|
||||||
|
return InterpretedFeatureLoss(opt_loss, env)
|
||||||
elif type == 'generator_gan':
|
elif type == 'generator_gan':
|
||||||
return GeneratorGanLoss(opt_loss, env)
|
return GeneratorGanLoss(opt_loss, env)
|
||||||
elif type == 'discriminator_gan':
|
elif type == 'discriminator_gan':
|
||||||
|
@ -57,7 +58,8 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
super(FeatureLoss, self).__init__(opt, env)
|
super(FeatureLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
self.netF = define_F(which_model=opt['which_model_F'],
|
||||||
|
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||||
if not env['opt']['dist']:
|
if not env['opt']['dist']:
|
||||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||||
|
|
||||||
|
@ -68,6 +70,28 @@ class FeatureLoss(ConfigurableLoss):
|
||||||
return self.criterion(logits_fake, logits_real)
|
return self.criterion(logits_fake, logits_real)
|
||||||
|
|
||||||
|
|
||||||
|
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
|
||||||
|
# network which was trained to replicate that embedding on an altered input space (for example, LR or greyscale) to
|
||||||
|
# compute the embedding in the generated space. Useful for weakening the influence of the feature network in controlled
|
||||||
|
# ways.
|
||||||
|
class InterpretedFeatureLoss(ConfigurableLoss):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(InterpretedFeatureLoss, self).__init__(opt, env)
|
||||||
|
self.opt = opt
|
||||||
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
|
self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||||
|
self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
||||||
|
if not env['opt']['dist']:
|
||||||
|
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
||||||
|
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
||||||
|
|
||||||
|
def forward(self, net, state):
|
||||||
|
with torch.no_grad():
|
||||||
|
logits_real = self.netF_real(state[self.opt['real']])
|
||||||
|
logits_fake = self.netF_gen(state[self.opt['fake']])
|
||||||
|
return self.criterion(logits_fake, logits_real)
|
||||||
|
|
||||||
|
|
||||||
class GeneratorGanLoss(ConfigurableLoss):
|
class GeneratorGanLoss(ConfigurableLoss):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(GeneratorGanLoss, self).__init__(opt, env)
|
super(GeneratorGanLoss, self).__init__(opt, env)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user