Modify geometric & translational losses so they can be used with embeddings
This commit is contained in:
parent
e30a1443cd
commit
146a9125f2
|
@ -528,7 +528,7 @@ class Spsr7(nn.Module):
|
|||
self.final_temperature_step = 10000
|
||||
self.lr = None
|
||||
|
||||
def forward(self, x, ref, ref_center):
|
||||
def forward(self, x, ref, ref_center, only_return_final_feature_map=False):
|
||||
# The attention_maps debugger outputs <x>. Save that here.
|
||||
self.lr = x.detach().cpu()
|
||||
|
||||
|
@ -551,13 +551,17 @@ class Spsr7(nn.Module):
|
|||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
x_grad = self.grad_lr_conv2(x_grad)
|
||||
x_grad_out = self.upsample_grad(x_grad)
|
||||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||
if not only_return_final_feature_map:
|
||||
x_grad_out = self.upsample_grad(x_grad)
|
||||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||
|
||||
x_out = x2
|
||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
final_feature_map = x_out
|
||||
if only_return_final_feature_map:
|
||||
return final_feature_map
|
||||
x_out = checkpoint(self.upsample, x_out)
|
||||
x_out = checkpoint(self.final_hr_conv1, x_out)
|
||||
x_out = self.final_hr_conv2(x_out)
|
||||
|
@ -565,7 +569,7 @@ class Spsr7(nn.Module):
|
|||
self.attentions = [a1, a2, a3, a4]
|
||||
self.grad_fea_std = grad_fea_std.detach().cpu()
|
||||
self.fea_grad_std = fea_grad_std.detach().cpu()
|
||||
return x_grad_out, x_out, s1out, s2out
|
||||
return x_grad_out, x_out, final_feature_map
|
||||
|
||||
def set_temperature(self, temp):
|
||||
[sw.set_temperature(temp) for sw in self.switches]
|
||||
|
|
|
@ -28,11 +28,16 @@ def create_loss(opt_loss, env):
|
|||
|
||||
|
||||
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
|
||||
def extract_params_from_state(params, state):
|
||||
def extract_params_from_state(params, state, root=True):
|
||||
if isinstance(params, list) or isinstance(params, tuple):
|
||||
p = [state[r] for r in params]
|
||||
p = [extract_params_from_state(r, state, False) for r in params]
|
||||
elif isinstance(params, str):
|
||||
p = state[params]
|
||||
else:
|
||||
p = [state[params]]
|
||||
p = params
|
||||
# The root return must always be a list.
|
||||
if root and not isinstance(p, list):
|
||||
p = [p]
|
||||
return p
|
||||
|
||||
|
||||
|
@ -241,7 +246,10 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
|
|||
# Undo alteration on HR image
|
||||
upsampled_altered = undo_fn(upsampled_altered)
|
||||
|
||||
return self.criterion(state[self.opt['real']], upsampled_altered)
|
||||
if self.opt['criterion'] == 'cosine':
|
||||
return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
|
||||
else:
|
||||
return self.criterion(state[self.opt['real']], upsampled_altered)
|
||||
|
||||
|
||||
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
||||
|
@ -280,11 +288,17 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
|||
trans_output = net(*input)
|
||||
else:
|
||||
trans_output = net(*input)
|
||||
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
||||
if self.gen_output_to_use:
|
||||
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
||||
else:
|
||||
fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
|
||||
|
||||
# The "real" input is assumed to always come from the top left tile.
|
||||
gen_output = state[self.opt['real']]
|
||||
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
|
||||
|
||||
return self.criterion(fake_shared_output, real_shared_output)
|
||||
if self.opt['criterion'] == 'cosine':
|
||||
return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
|
||||
else:
|
||||
return self.criterion(fake_shared_output, real_shared_output)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user