forked from mrq/DL-Art-School
Interpreted feature loss to extensibletrainer
This commit is contained in:
parent
886d59d5df
commit
8b52d46847
|
@ -2,7 +2,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
from models.networks import define_F
|
||||
from models.loss import GANLoss
|
||||
from torchvision.utils import save_image
|
||||
|
||||
|
||||
def create_generator_loss(opt_loss, env):
|
||||
|
@ -11,6 +10,8 @@ def create_generator_loss(opt_loss, env):
|
|||
return PixLoss(opt_loss, env)
|
||||
elif type == 'feature':
|
||||
return FeatureLoss(opt_loss, env)
|
||||
elif type == 'interpreted_feature':
|
||||
return InterpretedFeatureLoss(opt_loss, env)
|
||||
elif type == 'generator_gan':
|
||||
return GeneratorGanLoss(opt_loss, env)
|
||||
elif type == 'discriminator_gan':
|
||||
|
@ -57,7 +58,8 @@ class FeatureLoss(ConfigurableLoss):
|
|||
super(FeatureLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||
self.netF = define_F(which_model=opt['which_model_F'],
|
||||
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
||||
if not env['opt']['dist']:
|
||||
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
||||
|
||||
|
@ -68,6 +70,28 @@ class FeatureLoss(ConfigurableLoss):
|
|||
return self.criterion(logits_fake, logits_real)
|
||||
|
||||
|
||||
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
|
||||
# network which was trained to replicate that embedding on an altered input space (for example, LR or greyscale) to
|
||||
# compute the embedding in the generated space. Useful for weakening the influence of the feature network in controlled
|
||||
# ways.
|
||||
class InterpretedFeatureLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super(InterpretedFeatureLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
||||
self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
||||
if not env['opt']['dist']:
|
||||
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
||||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
||||
|
||||
def forward(self, net, state):
|
||||
with torch.no_grad():
|
||||
logits_real = self.netF_real(state[self.opt['real']])
|
||||
logits_fake = self.netF_gen(state[self.opt['fake']])
|
||||
return self.criterion(logits_fake, logits_real)
|
||||
|
||||
|
||||
class GeneratorGanLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super(GeneratorGanLoss, self).__init__(opt, env)
|
||||
|
|
Loading…
Reference in New Issue
Block a user