Add recursive loss
This commit is contained in:
parent
ffd069fd97
commit
13f97e1e97
|
@ -23,6 +23,8 @@ def create_loss(opt_loss, env):
|
|||
return GeometricSimilarityGeneratorLoss(opt_loss, env)
|
||||
elif type == 'translational':
|
||||
return TranslationInvarianceLoss(opt_loss, env)
|
||||
elif type == 'recursive':
|
||||
return RecursiveInvarianceLoss(opt_loss, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -328,3 +330,42 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
|||
else:
|
||||
return self.criterion(fake_shared_output, real_shared_output)
|
||||
|
||||
|
||||
# Computes a loss repeatedly feeding the generator downsampled inputs created from its outputs. The expectation is
|
||||
# that the generator's outputs do not change on repeated forward passes.
|
||||
# The "real" parameter to this loss is the actual output of the generator.
|
||||
# The "fake" parameter is the expected inputs that should be fed into the generator. 'input_alteration_index' is changed
|
||||
# so that it feeds the recursive input.
|
||||
class RecursiveInvarianceLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super(RecursiveInvarianceLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.generator = opt['generator']
|
||||
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||
self.gen_input_for_alteration = opt['input_alteration_index'] if 'input_alteration_index' in opt.keys() else 0
|
||||
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
|
||||
self.recursive_depth = opt['recursive_depth'] # How many times to recursively feed the output of the generator back into itself
|
||||
self.downsample_factor = opt['downsample_factor'] # Just 1/opt['scale']. Necessary since this loss doesnt have access to opt['scale'].
|
||||
assert(self.recursive_depth > 0)
|
||||
|
||||
def forward(self, net, state):
|
||||
self.metrics = []
|
||||
net = self.env['generators'][self.generator] # Get the network from an explicit parameter.
|
||||
# The <net> parameter is not reliable for generator losses since they can be combined with many networks.
|
||||
|
||||
gen_output = state[self.opt['real']]
|
||||
recurrent_gen_output = gen_output
|
||||
|
||||
fake = self.opt['fake'].copy()
|
||||
input = extract_params_from_state(fake, state)
|
||||
for i in range(self.recursive_depth):
|
||||
input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
|
||||
recurrent_gen_output = net(*input)[self.gen_output_to_use]
|
||||
|
||||
compare_real = gen_output
|
||||
compare_fake = recurrent_gen_output
|
||||
if self.opt['criterion'] == 'cosine':
|
||||
return self.criterion(compare_real, compare_fake, torch.ones(1, device=compare_real.device))
|
||||
else:
|
||||
return self.criterion(compare_real, compare_fake)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user