diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index d4133135..8b648be2 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn from models.networks import define_F from models.loss import GANLoss -from torchvision.utils import save_image def create_generator_loss(opt_loss, env): @@ -11,6 +10,8 @@ def create_generator_loss(opt_loss, env): return PixLoss(opt_loss, env) elif type == 'feature': return FeatureLoss(opt_loss, env) + elif type == 'interpreted_feature': + return InterpretedFeatureLoss(opt_loss, env) elif type == 'generator_gan': return GeneratorGanLoss(opt_loss, env) elif type == 'discriminator_gan': @@ -57,7 +58,8 @@ class FeatureLoss(ConfigurableLoss): super(FeatureLoss, self).__init__(opt, env) self.opt = opt self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) - self.netF = define_F(which_model=opt['which_model_F']).to(self.env['device']) + self.netF = define_F(which_model=opt['which_model_F'], + load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device']) if not env['opt']['dist']: self.netF = torch.nn.parallel.DataParallel(self.netF) @@ -68,6 +70,28 @@ class FeatureLoss(ConfigurableLoss): return self.criterion(logits_fake, logits_real) +# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second +# network which was trained to replicate that embedding on an altered input space (for example, LR or greyscale) to +# compute the embedding in the generated space. Useful for weakening the influence of the feature network in controlled +# ways. +class InterpretedFeatureLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(InterpretedFeatureLoss, self).__init__(opt, env) + self.opt = opt + self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) + self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device']) + self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device']) + if not env['opt']['dist']: + self.netF_real = torch.nn.parallel.DataParallel(self.netF_real) + self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen) + + def forward(self, net, state): + with torch.no_grad(): + logits_real = self.netF_real(state[self.opt['real']]) + logits_fake = self.netF_gen(state[self.opt['fake']]) + return self.criterion(logits_fake, logits_real) + + class GeneratorGanLoss(ConfigurableLoss): def __init__(self, opt, env): super(GeneratorGanLoss, self).__init__(opt, env)