555 lines
25 KiB
Python
555 lines
25 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.cuda.amp import autocast
|
|
|
|
from models.networks import define_F
|
|
from models.loss import GANLoss
|
|
import random
|
|
import functools
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
|
|
def create_loss(opt_loss, env):
|
|
type = opt_loss['type']
|
|
if 'teco_' in type:
|
|
from models.steps.tecogan_losses import create_teco_loss
|
|
return create_teco_loss(opt_loss, env)
|
|
elif type == 'pix':
|
|
return PixLoss(opt_loss, env)
|
|
elif type == 'direct':
|
|
return DirectLoss(opt_loss, env)
|
|
elif type == 'feature':
|
|
return FeatureLoss(opt_loss, env)
|
|
elif type == 'interpreted_feature':
|
|
return InterpretedFeatureLoss(opt_loss, env)
|
|
elif type == 'generator_gan':
|
|
return GeneratorGanLoss(opt_loss, env)
|
|
elif type == 'discriminator_gan':
|
|
return DiscriminatorGanLoss(opt_loss, env)
|
|
elif type == 'geometric':
|
|
return GeometricSimilarityGeneratorLoss(opt_loss, env)
|
|
elif type == 'translational':
|
|
return TranslationInvarianceLoss(opt_loss, env)
|
|
elif type == 'recursive':
|
|
return RecursiveInvarianceLoss(opt_loss, env)
|
|
elif type == 'recurrent':
|
|
return RecurrentLoss(opt_loss, env)
|
|
elif type == 'for_element':
|
|
return ForElementLoss(opt_loss, env)
|
|
elif type == 'stylegan2_divergence':
|
|
return StyleGan2DivergenceLoss(opt_loss, env)
|
|
elif type == 'stylegan2_pathlen':
|
|
return StyleGan2PathLengthLoss(opt_loss, env)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
|
|
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
|
|
def extract_params_from_state(params: object, state: object, root: object = True) -> object:
|
|
if isinstance(params, list) or isinstance(params, tuple):
|
|
p = [extract_params_from_state(r, state, False) for r in params]
|
|
elif isinstance(params, str):
|
|
if params == 'None':
|
|
p = None
|
|
else:
|
|
p = state[params]
|
|
else:
|
|
p = params
|
|
# The root return must always be a list.
|
|
if root and not isinstance(p, list):
|
|
p = [p]
|
|
return p
|
|
|
|
|
|
class ConfigurableLoss(nn.Module):
|
|
def __init__(self, opt, env):
|
|
super(ConfigurableLoss, self).__init__()
|
|
self.opt = opt
|
|
self.env = env
|
|
self.metrics = []
|
|
|
|
# net is either a scalar network being trained or a list of networks being trained, depending on the configuration.
|
|
def forward(self, net, state):
|
|
raise NotImplementedError
|
|
|
|
def extra_metrics(self):
|
|
return self.metrics
|
|
|
|
def clear_metrics(self):
|
|
self.metrics = []
|
|
|
|
|
|
def get_basic_criterion_for_name(name, device):
|
|
if name == 'l1':
|
|
return nn.L1Loss().to(device)
|
|
elif name == 'l2':
|
|
return nn.MSELoss().to(device)
|
|
elif name == 'cosine':
|
|
return nn.CosineEmbeddingLoss().to(device)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
|
|
class PixLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(PixLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
|
self.real_scale = opt['real_scale'] if 'real_scale' in opt.keys() else 1
|
|
self.real_offset = opt['real_offset'] if 'real_offset' in opt.keys() else 0
|
|
self.report_metrics = opt['report_metrics'] if 'report_metrics' in opt.keys() else False
|
|
|
|
def forward(self, _, state):
|
|
real = state[self.opt['real']] * self.real_scale + float(self.real_offset)
|
|
fake = state[self.opt['fake']]
|
|
if self.report_metrics:
|
|
self.metrics.append(("real_pix_mean_histogram", torch.mean(real, dim=[1,2,3]).detach()))
|
|
self.metrics.append(("fake_pix_mean_histogram", torch.mean(fake, dim=[1,2,3]).detach()))
|
|
self.metrics.append(("real_pix_std", torch.std(real).detach()))
|
|
self.metrics.append(("fake_pix_std", torch.std(fake).detach()))
|
|
return self.criterion(fake.float(), real.float())
|
|
|
|
|
|
# Loss defined by averaging the input tensor across all dimensions an optionally inverting it.
|
|
class DirectLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(DirectLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.inverted = opt['inverted'] if 'inverted' in opt.keys() else False
|
|
self.key = opt['key']
|
|
|
|
def forward(self, _, state):
|
|
if self.inverted:
|
|
return -torch.mean(state[self.key])
|
|
else:
|
|
return torch.mean(state[self.key])
|
|
|
|
|
|
class FeatureLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(FeatureLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
|
self.netF = define_F(which_model=opt['which_model_F'],
|
|
load_path=opt['load_path'] if 'load_path' in opt.keys() else None).to(self.env['device'])
|
|
if not env['opt']['dist']:
|
|
self.netF = torch.nn.parallel.DataParallel(self.netF)
|
|
|
|
def forward(self, _, state):
|
|
with autocast(enabled=self.env['opt']['fp16']):
|
|
with torch.no_grad():
|
|
logits_real = self.netF(state[self.opt['real']])
|
|
logits_fake = self.netF(state[self.opt['fake']])
|
|
if self.opt['criterion'] == 'cosine':
|
|
return self.criterion(logits_fake.float(), logits_real.float(), torch.ones(1, device=logits_fake.device))
|
|
else:
|
|
return self.criterion(logits_fake.float(), logits_real.float())
|
|
|
|
|
|
# Special form of feature loss which first computes the feature embedding for the truth space, then uses a second
|
|
# network which was trained to replicate that embedding on an altered input space (for example, LR or greyscale) to
|
|
# compute the embedding in the generated space. Useful for weakening the influence of the feature network in controlled
|
|
# ways.
|
|
class InterpretedFeatureLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(InterpretedFeatureLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
|
self.netF_real = define_F(which_model=opt['which_model_F']).to(self.env['device'])
|
|
self.netF_gen = define_F(which_model=opt['which_model_F'], load_path=opt['load_path']).to(self.env['device'])
|
|
if not env['opt']['dist']:
|
|
self.netF_real = torch.nn.parallel.DataParallel(self.netF_real)
|
|
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
|
|
|
def forward(self, _, state):
|
|
logits_real = self.netF_real(state[self.opt['real']])
|
|
logits_fake = self.netF_gen(state[self.opt['fake']])
|
|
return self.criterion(logits_fake.float(), logits_real.float())
|
|
|
|
|
|
class GeneratorGanLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(GeneratorGanLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
|
self.noise = None if 'noise' not in opt.keys() else opt['noise']
|
|
self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
|
|
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
|
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
|
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
|
if self.min_loss != 0:
|
|
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
|
self.rb_ptr = 0
|
|
self.losses_computed = 0
|
|
|
|
def forward(self, _, state):
|
|
netD = self.env['discriminators'][self.opt['discriminator']]
|
|
real = extract_params_from_state(self.opt['real'], state)
|
|
fake = extract_params_from_state(self.opt['fake'], state)
|
|
if self.noise:
|
|
nreal = []
|
|
nfake = []
|
|
for i, t in enumerate(real):
|
|
if isinstance(t, torch.Tensor):
|
|
nreal.append(t + torch.randn_like(t) * self.noise)
|
|
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
|
else:
|
|
nreal.append(t)
|
|
nfake.append(fake[i])
|
|
real = nreal
|
|
fake = nfake
|
|
with autocast(enabled=self.env['opt']['fp16']):
|
|
if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
|
|
pred_g_fake = netD(*fake)
|
|
loss = self.criterion(pred_g_fake, True)
|
|
elif self.opt['gan_type'] == 'ragan':
|
|
pred_d_real = netD(*real)
|
|
if self.detach_real:
|
|
pred_d_real = pred_d_real.detach()
|
|
pred_g_fake = netD(*fake)
|
|
d_fake_diff = pred_g_fake - torch.mean(pred_d_real)
|
|
self.metrics.append(("d_fake", torch.mean(pred_g_fake)))
|
|
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
|
loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
|
|
self.criterion(d_fake_diff, True)) / 2
|
|
else:
|
|
raise NotImplementedError
|
|
if self.min_loss != 0:
|
|
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
|
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
|
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
|
return 0
|
|
self.losses_computed += 1
|
|
self.metrics.append(("loss_counter", self.losses_computed))
|
|
return loss
|
|
|
|
|
|
class DiscriminatorGanLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(DiscriminatorGanLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
|
self.noise = None if 'noise' not in opt.keys() else opt['noise']
|
|
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
|
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
|
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
|
if self.min_loss != 0:
|
|
assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
|
|
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
|
self.rb_ptr = 0
|
|
self.losses_computed = 0
|
|
|
|
def forward(self, net, state):
|
|
real = extract_params_from_state(self.opt['real'], state)
|
|
real = [r.detach() for r in real]
|
|
fake = extract_params_from_state(self.opt['fake'], state)
|
|
fake = [f.detach() for f in fake]
|
|
if self.noise:
|
|
nreal = []
|
|
nfake = []
|
|
for i, t in enumerate(real):
|
|
if isinstance(t, torch.Tensor):
|
|
nreal.append(t + torch.randn_like(t) * self.noise)
|
|
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
|
else:
|
|
nreal.append(t)
|
|
nfake.append(fake[i])
|
|
real = nreal
|
|
fake = nfake
|
|
with autocast(enabled=self.env['opt']['fp16']):
|
|
d_real = net(*real)
|
|
d_fake = net(*fake)
|
|
|
|
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
|
self.metrics.append(("d_real", torch.mean(d_real)))
|
|
l_real = self.criterion(d_real, True)
|
|
l_fake = self.criterion(d_fake, False)
|
|
l_total = l_real + l_fake
|
|
loss = l_total
|
|
elif self.opt['gan_type'] == 'ragan' or self.opt['gan_type'] == 'max_spread':
|
|
d_fake_diff = d_fake - torch.mean(d_real)
|
|
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
|
loss = (self.criterion(d_real - torch.mean(d_fake), True) +
|
|
self.criterion(d_fake_diff, False))
|
|
else:
|
|
raise NotImplementedError
|
|
if self.min_loss != 0:
|
|
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
|
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
|
self.metrics.append(("loss_counter", self.losses_computed))
|
|
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
|
return 0
|
|
self.losses_computed += 1
|
|
return loss
|
|
|
|
|
|
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
|
# input that has been altered randomly by rotation or flip.
|
|
# The "real" parameter to this loss is the actual output of the generator (from an injection point)
|
|
# The "fake" parameter is the LR input that produced the "real" parameter when fed through the generator.
|
|
class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(GeometricSimilarityGeneratorLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.generator = opt['generator']
|
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
|
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
|
|
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
|
|
self.detach_fake = opt['detach_fake'] if 'detach_fake' in opt.keys() else False
|
|
|
|
# Returns a random alteration and its counterpart (that undoes the alteration)
|
|
def random_alteration(self):
|
|
return random.choice([(functools.partial(torch.flip, dims=(2,)), functools.partial(torch.flip, dims=(2,))),
|
|
(functools.partial(torch.flip, dims=(3,)), functools.partial(torch.flip, dims=(3,))),
|
|
(functools.partial(torch.rot90, k=1, dims=[2,3]), functools.partial(torch.rot90, k=3, dims=[2,3])),
|
|
(functools.partial(torch.rot90, k=2, dims=[2,3]), functools.partial(torch.rot90, k=2, dims=[2,3])),
|
|
(functools.partial(torch.rot90, k=3, dims=[2,3]), functools.partial(torch.rot90, k=1, dims=[2,3]))])
|
|
|
|
def forward(self, net, state):
|
|
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
|
# The <net> parameter is not reliable for generator losses since often they are combined with many networks.
|
|
fake = extract_params_from_state(self.opt['fake'], state)
|
|
alteration, undo_fn = self.random_alteration()
|
|
altered = []
|
|
for i, t in enumerate(fake):
|
|
if i == self.gen_input_for_alteration:
|
|
altered.append(alteration(t))
|
|
else:
|
|
altered.append(t)
|
|
|
|
with autocast(enabled=self.env['opt']['fp16']):
|
|
if self.detach_fake:
|
|
with torch.no_grad():
|
|
upsampled_altered = net(*altered)
|
|
else:
|
|
upsampled_altered = net(*altered)
|
|
|
|
if self.gen_output_to_use is not None:
|
|
upsampled_altered = upsampled_altered[self.gen_output_to_use]
|
|
|
|
# Undo alteration on HR image
|
|
upsampled_altered = undo_fn(upsampled_altered)
|
|
|
|
if self.opt['criterion'] == 'cosine':
|
|
return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
|
|
else:
|
|
return self.criterion(state[self.opt['real']].float(), upsampled_altered.float())
|
|
|
|
|
|
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
|
# input that has been translated in a random direction.
|
|
# The "real" parameter to this loss is the actual output of the generator on the top left image patch.
|
|
# The "fake" parameter is the output base fed into a ImagePatchInjector.
|
|
class TranslationInvarianceLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(TranslationInvarianceLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.generator = opt['generator']
|
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
|
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
|
|
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
|
|
self.patch_size = opt['patch_size']
|
|
self.overlap = opt['overlap'] # For maximum overlap, can be calculated as 2*patch_size-image_size
|
|
self.detach_fake = opt['detach_fake']
|
|
assert(self.patch_size > self.overlap)
|
|
|
|
def forward(self, net, state):
|
|
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
|
# The <net> parameter is not reliable for generator losses since often they are combined with many networks.
|
|
|
|
border_sz = self.patch_size - self.overlap
|
|
translation = random.choice([("top_right", border_sz, border_sz+self.overlap, 0, self.overlap),
|
|
("bottom_left", 0, self.overlap, border_sz, border_sz+self.overlap),
|
|
("bottom_right", 0, self.overlap, 0, self.overlap)])
|
|
trans_name, hl, hh, wl, wh = translation
|
|
# Change the "fake" input name that we are translating to one that specifies the random translation.
|
|
fake = self.opt['fake'].copy()
|
|
fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
|
|
input = extract_params_from_state(fake, state)
|
|
|
|
with autocast(enabled=self.env['opt']['fp16']):
|
|
if self.detach_fake:
|
|
with torch.no_grad():
|
|
trans_output = net(*input)
|
|
else:
|
|
trans_output = net(*input)
|
|
|
|
if self.gen_output_to_use is not None:
|
|
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
|
else:
|
|
fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
|
|
|
|
# The "real" input is assumed to always come from the top left tile.
|
|
gen_output = state[self.opt['real']]
|
|
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
|
|
|
|
if self.opt['criterion'] == 'cosine':
|
|
return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
|
|
else:
|
|
return self.criterion(fake_shared_output.float(), real_shared_output.float())
|
|
|
|
|
|
# Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is
|
|
# that the generator's outputs do not change on repeated forward passes.
|
|
# The "real" parameter to this loss is the actual output of the generator.
|
|
# The "fake" parameter is the expected inputs that should be fed into the generator. 'input_alteration_index' is changed
|
|
# so that it feeds the recursive input.
|
|
class RecursiveInvarianceLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(RecursiveInvarianceLoss, self).__init__(opt, env)
|
|
self.opt = opt
|
|
self.generator = opt['generator']
|
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
|
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
|
|
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
|
|
self.recursive_depth = opt['recursive_depth'] # How many times to recursively feed the output of the generator back into itself
|
|
self.downsample_factor = opt['downsample_factor'] # Just 1/opt['scale']. Necessary since this loss doesnt have access to opt['scale'].
|
|
assert(self.recursive_depth > 0)
|
|
|
|
def forward(self, net, state):
|
|
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
|
# The <net> parameter is not reliable for generator losses since they can be combined with many networks.
|
|
|
|
gen_output = state[self.opt['real']]
|
|
recurrent_gen_output = gen_output
|
|
|
|
fake = self.opt['fake'].copy()
|
|
input = extract_params_from_state(fake, state)
|
|
for i in range(self.recursive_depth):
|
|
input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
|
|
with autocast(enabled=self.env['opt']['fp16']):
|
|
recurrent_gen_output = net(*input)[self.gen_output_to_use]
|
|
|
|
compare_real = gen_output
|
|
compare_fake = recurrent_gen_output
|
|
if self.opt['criterion'] == 'cosine':
|
|
return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device))
|
|
else:
|
|
return self.criterion(compare_real.float(), compare_fake.float())
|
|
|
|
|
|
# Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the
|
|
# 'subtype' loss.
|
|
class RecurrentLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(RecurrentLoss, self).__init__(opt, env)
|
|
o = opt.copy()
|
|
o['type'] = opt['subtype']
|
|
o['fake'] = '_fake'
|
|
o['real'] = '_real'
|
|
self.loss = create_loss(o, self.env)
|
|
# Use this option to specify a differential weighting scheme for losses inside of the recurrent construct. For
|
|
# example, if later recurrent outputs should contribute more to the loss than earlier ones. When specified,
|
|
# must be a list of weights that exactly aligns with the recurrent list fed to forward().
|
|
self.recurrent_weights = opt['recurrent_weights'] if 'recurrent_weights' in opt.keys() else 1
|
|
|
|
def forward(self, net, state):
|
|
total_loss = 0
|
|
st = state.copy()
|
|
real = state[self.opt['real']]
|
|
for i in range(real.shape[1]):
|
|
st['_real'] = real[:, i]
|
|
st['_fake'] = state[self.opt['fake']][:, i]
|
|
subloss = self.loss(net, st)
|
|
if isinstance(self.recurrent_weights, list):
|
|
subloss = subloss * self.recurrent_weights[i]
|
|
total_loss += subloss
|
|
return total_loss
|
|
|
|
def extra_metrics(self):
|
|
return self.loss.extra_metrics()
|
|
|
|
def clear_metrics(self):
|
|
self.loss.clear_metrics()
|
|
|
|
|
|
# Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss.
|
|
class ForElementLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super(ForElementLoss, self).__init__(opt, env)
|
|
o = opt.copy()
|
|
o['type'] = opt['subtype']
|
|
self.index = opt['index']
|
|
o['fake'] = '_fake'
|
|
o['real'] = '_real'
|
|
self.loss = create_loss(o, self.env)
|
|
|
|
def forward(self, net, state):
|
|
st = state.copy()
|
|
st['_real'] = state[self.opt['real']][:, self.index]
|
|
st['_fake'] = state[self.opt['fake']][:, self.index]
|
|
return self.loss(net, st)
|
|
|
|
def extra_metrics(self):
|
|
return self.loss.extra_metrics()
|
|
|
|
def clear_metrics(self):
|
|
self.loss.clear_metrics()
|
|
|
|
|
|
class StyleGan2DivergenceLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super().__init__(opt, env)
|
|
self.real = opt['real']
|
|
self.fake = opt['fake']
|
|
self.discriminator = opt['discriminator']
|
|
self.for_gen = opt['gen_loss']
|
|
self.gp_frequency = opt['gradient_penalty_frequency']
|
|
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
|
|
|
def forward(self, net, state):
|
|
real_input = state[self.real]
|
|
fake_input = state[self.fake]
|
|
if self.noise != 0:
|
|
fake_input = fake_input + torch.rand_like(fake_input) * self.noise
|
|
real_input = real_input + torch.rand_like(real_input) * self.noise
|
|
|
|
D = self.env['discriminators'][self.discriminator]
|
|
fake = D(fake_input)
|
|
if self.for_gen:
|
|
return fake.mean()
|
|
else:
|
|
real_input.requires_grad_() # <-- Needed to compute gradients on the input.
|
|
real = D(real_input)
|
|
divergence_loss = (F.relu(1 + real) + F.relu(1 - fake)).mean()
|
|
|
|
# Apply gradient penalty. TODO: migrate this elsewhere.
|
|
if self.env['step'] % self.gp_frequency == 0:
|
|
from models.archs.stylegan2 import gradient_penalty
|
|
gp = gradient_penalty(real_input, real)
|
|
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
|
divergence_loss = divergence_loss + gp
|
|
|
|
real_input.requires_grad_(requires_grad=False)
|
|
return divergence_loss
|
|
|
|
|
|
class StyleGan2PathLengthLoss(ConfigurableLoss):
|
|
def __init__(self, opt, env):
|
|
super().__init__(opt, env)
|
|
self.w_styles = opt['w_styles']
|
|
self.gen = opt['gen']
|
|
self.pl_mean = None
|
|
from models.archs.stylegan2 import EMA
|
|
self.pl_length_ma = EMA(.99)
|
|
|
|
def forward(self, net, state):
|
|
w_styles = state[self.w_styles]
|
|
gen = state[self.gen]
|
|
from models.archs.stylegan2 import calc_pl_lengths
|
|
pl_lengths = calc_pl_lengths(w_styles, gen)
|
|
avg_pl_length = np.mean(pl_lengths.detach().cpu().numpy())
|
|
|
|
from models.archs.stylegan2 import is_empty
|
|
if not is_empty(self.pl_mean):
|
|
pl_loss = ((pl_lengths - self.pl_mean) ** 2).mean()
|
|
if not torch.isnan(pl_loss):
|
|
return pl_loss
|
|
else:
|
|
print("Path length loss returned NaN!")
|
|
|
|
self.pl_mean = self.pl_length_ma.update_average(self.pl_mean, avg_pl_length)
|
|
return 0
|