diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 14d995e1..85fcf63a 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -23,6 +23,8 @@ def create_loss(opt_loss, env): return GeometricSimilarityGeneratorLoss(opt_loss, env) elif type == 'translational': return TranslationInvarianceLoss(opt_loss, env) + elif type == 'recursive': + return RecursiveInvarianceLoss(opt_loss, env) else: raise NotImplementedError @@ -328,3 +330,42 @@ class TranslationInvarianceLoss(ConfigurableLoss): else: return self.criterion(fake_shared_output, real_shared_output) + +# Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is +# that the generator's outputs do not change on repeated forward passes. +# The "real" parameter to this loss is the actual output of the generator. +# The "fake" parameter is the expected inputs that should be fed into the generator. 'input_alteration_index' is changed +# so that it feeds the recursive input. +class RecursiveInvarianceLoss(ConfigurableLoss): + def __init__(self, opt, env): + super(RecursiveInvarianceLoss, self).__init__(opt, env) + self.opt = opt + self.generator = opt['generator'] + self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device']) + self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0 + self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None + self.recursive_depth = opt['recursive_depth'] # How many times to recursively feed the output of the generator back into itself + self.downsample_factor = opt['downsample_factor'] # Just 1/opt['scale']. Necessary since this loss doesnt have access to opt['scale']. + assert(self.recursive_depth > 0) + + def forward(self, net, state): + self.metrics = [] + net = self.env['generators'][self.generator] # Get the network from an explicit parameter. + # The parameter is not reliable for generator losses since they can be combined with many networks. + + gen_output = state[self.opt['real']] + recurrent_gen_output = gen_output + + fake = self.opt['fake'].copy() + input = extract_params_from_state(fake, state) + for i in range(self.recursive_depth): + input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest") + recurrent_gen_output = net(*input)[self.gen_output_to_use] + + compare_real = gen_output + compare_fake = recurrent_gen_output + if self.opt['criterion'] == 'cosine': + return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device)) + else: + return self.criterion(compare_real, compare_fake) +