Modify geometric & translational losses so they can be used with embeddings

This commit is contained in:
James Betker 2020-10-02 20:40:13 -06:00
parent e30a1443cd
commit 146a9125f2
2 changed files with 28 additions and 10 deletions

View File

@ -528,7 +528,7 @@ class Spsr7(nn.Module):
self.final_temperature_step = 10000 self.final_temperature_step = 10000
self.lr = None self.lr = None
def forward(self, x, ref, ref_center): def forward(self, x, ref, ref_center, only_return_final_feature_map=False):
# The attention_maps debugger outputs <x>. Save that here. # The attention_maps debugger outputs <x>. Save that here.
self.lr = x.detach().cpu() self.lr = x.detach().cpu()
@ -551,13 +551,17 @@ class Spsr7(nn.Module):
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding)) x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
x_grad = self.grad_lr_conv(x_grad) x_grad = self.grad_lr_conv(x_grad)
x_grad = self.grad_lr_conv2(x_grad) x_grad = self.grad_lr_conv2(x_grad)
x_grad_out = self.upsample_grad(x_grad) if not only_return_final_feature_map:
x_grad_out = self.grad_branch_output_conv(x_grad_out) x_grad_out = self.upsample_grad(x_grad)
x_grad_out = self.grad_branch_output_conv(x_grad_out)
x_out = x2 x_out = x2
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad) x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding)) x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
x_out = self.final_lr_conv(x_out) x_out = self.final_lr_conv(x_out)
final_feature_map = x_out
if only_return_final_feature_map:
return final_feature_map
x_out = checkpoint(self.upsample, x_out) x_out = checkpoint(self.upsample, x_out)
x_out = checkpoint(self.final_hr_conv1, x_out) x_out = checkpoint(self.final_hr_conv1, x_out)
x_out = self.final_hr_conv2(x_out) x_out = self.final_hr_conv2(x_out)
@ -565,7 +569,7 @@ class Spsr7(nn.Module):
self.attentions = [a1, a2, a3, a4] self.attentions = [a1, a2, a3, a4]
self.grad_fea_std = grad_fea_std.detach().cpu() self.grad_fea_std = grad_fea_std.detach().cpu()
self.fea_grad_std = fea_grad_std.detach().cpu() self.fea_grad_std = fea_grad_std.detach().cpu()
return x_grad_out, x_out, s1out, s2out return x_grad_out, x_out, final_feature_map
def set_temperature(self, temp): def set_temperature(self, temp):
[sw.set_temperature(temp) for sw in self.switches] [sw.set_temperature(temp) for sw in self.switches]

View File

@ -28,11 +28,16 @@ def create_loss(opt_loss, env):
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars. # Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
def extract_params_from_state(params, state): def extract_params_from_state(params, state, root=True):
if isinstance(params, list) or isinstance(params, tuple): if isinstance(params, list) or isinstance(params, tuple):
p = [state[r] for r in params] p = [extract_params_from_state(r, state, False) for r in params]
elif isinstance(params, str):
p = state[params]
else: else:
p = [state[params]] p = params
# The root return must always be a list.
if root and not isinstance(p, list):
p = [p]
return p return p
@ -241,7 +246,10 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
# Undo alteration on HR image # Undo alteration on HR image
upsampled_altered = undo_fn(upsampled_altered) upsampled_altered = undo_fn(upsampled_altered)
return self.criterion(state[self.opt['real']], upsampled_altered) if self.opt['criterion'] == 'cosine':
return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
else:
return self.criterion(state[self.opt['real']], upsampled_altered)
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
@ -280,11 +288,17 @@ class TranslationInvarianceLoss(ConfigurableLoss):
trans_output = net(*input) trans_output = net(*input)
else: else:
trans_output = net(*input) trans_output = net(*input)
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh] if self.gen_output_to_use:
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
else:
fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
# The "real" input is assumed to always come from the top left tile. # The "real" input is assumed to always come from the top left tile.
gen_output = state[self.opt['real']] gen_output = state[self.opt['real']]
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap] real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
return self.criterion(fake_shared_output, real_shared_output) if self.opt['criterion'] == 'cosine':
return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
else:
return self.criterion(fake_shared_output, real_shared_output)