Interpreted feature loss to extensibletrainer

This commit is contained in:
James Betker 2020-09-02 10:08:24 -06:00
parent 886d59d5df
commit 8b52d46847

View File

@ -2,7 +2,6 @@ import torch
import torch.nn as nn
from models.networks import define_F
from models.loss import GANLoss
from torchvision.utils import save_image
def create_generator_loss(opt_loss, env):
@ -11,6 +10,8 @@ def create_generator_loss(opt_loss, env):
return PixLoss(opt_loss, env)
elif type == 'feature':
return FeatureLoss(opt_loss, env)
elif type == 'interpreted_feature':
return InterpretedFeatureLoss(opt_loss, env)
elif type == 'generator_gan':
return GeneratorGanLoss(opt_loss, env)
elif type == 'discriminator_gan':
@ -57,7 +58,8 @@ class FeatureLoss(ConfigurableLoss):
super(FeatureLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device'])
self.netF = define_F(which_model=opt['which_model_F'],
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
if not env['opt']['dist']:
self.netF = torch.nn.parallel.DataParallel(self.netF)
@ -68,6 +70,28 @@ class FeatureLoss(ConfigurableLoss):
return self.criterion(logits_fake, logits_real)
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
# network which was trained to replicate that embedding on an altered input space (for example, LR or greyscale) to
# compute the embedding in the generated space. Useful for weakening the influence of the feature network in controlled
# ways.
class InterpretedFeatureLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(InterpretedFeatureLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device'])
self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
if not env['opt']['dist']:
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
def forward(self, net, state):
with torch.no_grad():
logits_real = self.netF_real(state[self.opt['real']])
logits_fake = self.netF_gen(state[self.opt['fake']])
return self.criterion(logits_fake, logits_real)
class GeneratorGanLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(GeneratorGanLoss, self).__init__(opt, env)