Add recursive loss

This commit is contained in:
James Betker 2020-10-04 20:48:15 -06:00
parent ffd069fd97
commit 13f97e1e97

View File

@ -23,6 +23,8 @@ def create_loss(opt_loss, env):
return GeometricSimilarityGeneratorLoss(opt_loss, env) return GeometricSimilarityGeneratorLoss(opt_loss, env)
elif type == 'translational': elif type == 'translational':
return TranslationInvarianceLoss(opt_loss, env) return TranslationInvarianceLoss(opt_loss, env)
elif type == 'recursive':
return RecursiveInvarianceLoss(opt_loss, env)
else: else:
raise NotImplementedError raise NotImplementedError
@ -328,3 +330,42 @@ class TranslationInvarianceLoss(ConfigurableLoss):
else: else:
return self.criterion(fake_shared_output, real_shared_output) return self.criterion(fake_shared_output, real_shared_output)
# Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is
# that the generator's outputs do not change on repeated forward passes.
# The "real" parameter to this loss is the actual output of the generator.
# The "fake" parameter is the expected inputs that should be fed into the generator. 'input_alteration_index' is changed
# so that it feeds the recursive input.
class RecursiveInvarianceLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(RecursiveInvarianceLoss, self).__init__(opt, env)
self.opt = opt
self.generator = opt['generator']
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
self.recursive_depth = opt['recursive_depth'] # How many times to recursively feed the output of the generator back into itself
self.downsample_factor = opt['downsample_factor'] # Just 1/opt['scale']. Necessary since this loss doesnt have access to opt['scale'].
assert(self.recursive_depth > 0)
def forward(self, net, state):
self.metrics = []
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
# The <net> parameter is not reliable for generator losses since they can be combined with many networks.
gen_output = state[self.opt['real']]
recurrent_gen_output = gen_output
fake = self.opt['fake'].copy()
input = extract_params_from_state(fake, state)
for i in range(self.recursive_depth):
input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
recurrent_gen_output = net(*input)[self.gen_output_to_use]
compare_real = gen_output
compare_fake = recurrent_gen_output
if self.opt['criterion'] == 'cosine':
return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device))
else:
return self.criterion(compare_real, compare_fake)