Modify geometric & translational losses so they can be used with embeddings

This commit is contained in:
James Betker 2020-10-02 20:40:13 -06:00
parent e30a1443cd
commit 146a9125f2
2 changed files with 28 additions and 10 deletions

View File

@ -528,7 +528,7 @@ class Spsr7(nn.Module):
self.final_temperature_step = 10000
self.lr = None
def forward(self, x, ref, ref_center):
def forward(self, x, ref, ref_center, only_return_final_feature_map=False):
# The attention_maps debugger outputs <x>. Save that here.
self.lr = x.detach().cpu()
@ -551,6 +551,7 @@ class Spsr7(nn.Module):
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_lr_conv2(x_grad)
if not only_return_final_feature_map:
x_grad_out = self.upsample_grad(x_grad)
x_grad_out = self.grad_branch_output_conv(x_grad_out)
@ -558,6 +559,9 @@ class Spsr7(nn.Module):
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
x_out = self.final_lr_conv(x_out)
final_feature_map = x_out
if only_return_final_feature_map:
return final_feature_map
x_out = checkpoint(self.upsample, x_out)
x_out = checkpoint(self.final_hr_conv1, x_out)
x_out = self.final_hr_conv2(x_out)
@ -565,7 +569,7 @@ class Spsr7(nn.Module):
self.attentions = [a1, a2, a3, a4]
self.grad_fea_std = grad_fea_std.detach().cpu()
self.fea_grad_std = fea_grad_std.detach().cpu()
return x_grad_out, x_out, s1out, s2out
return x_grad_out, x_out, final_feature_map
def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches]

View File

@ -28,11 +28,16 @@ def create_loss(opt_loss, env):
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
def extract_params_from_state(params, state):
def extract_params_from_state(params, state, root=True):
if isinstance(params, list) or isinstance(params, tuple):
p = [state[r] for r in params]
p = [extract_params_from_state(r, state, False) for r in params]
elif isinstance(params, str):
p = state[params]
else:
p = [state[params]]
p = params
# The root return must always be a list.
if root and not isinstance(p, list):
p = [p]
return p
@ -241,6 +246,9 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
# Undo alteration on HR image
upsampled_altered = undo_fn(upsampled_altered)
if self.opt['criterion'] == 'cosine':
return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
else:
return self.criterion(state[self.opt['real']], upsampled_altered)
@ -280,11 +288,17 @@ class TranslationInvarianceLoss(ConfigurableLoss):
trans_output = net(*input)
else:
trans_output = net(*input)
if self.gen_output_to_use:
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
else:
fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
# The "real" input is assumed to always come from the top left tile.
gen_output = state[self.opt['real']]
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
if self.opt['criterion'] == 'cosine':
return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
else:
return self.criterion(fake_shared_output, real_shared_output)