forked from mrq/DL-Art-School
Modify geometric & translational losses so they can be used with embeddings
This commit is contained in:
parent
e30a1443cd
commit
146a9125f2
|
@ -528,7 +528,7 @@ class Spsr7(nn.Module):
|
||||||
self.final_temperature_step = 10000
|
self.final_temperature_step = 10000
|
||||||
self.lr = None
|
self.lr = None
|
||||||
|
|
||||||
def forward(self, x, ref, ref_center):
|
def forward(self, x, ref, ref_center, only_return_final_feature_map=False):
|
||||||
# The attention_maps debugger outputs <x>. Save that here.
|
# The attention_maps debugger outputs <x>. Save that here.
|
||||||
self.lr = x.detach().cpu()
|
self.lr = x.detach().cpu()
|
||||||
|
|
||||||
|
@ -551,13 +551,17 @@ class Spsr7(nn.Module):
|
||||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
|
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
|
||||||
x_grad = self.grad_lr_conv(x_grad)
|
x_grad = self.grad_lr_conv(x_grad)
|
||||||
x_grad = self.grad_lr_conv2(x_grad)
|
x_grad = self.grad_lr_conv2(x_grad)
|
||||||
x_grad_out = self.upsample_grad(x_grad)
|
if not only_return_final_feature_map:
|
||||||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
x_grad_out = self.upsample_grad(x_grad)
|
||||||
|
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||||
|
|
||||||
x_out = x2
|
x_out = x2
|
||||||
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
|
||||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
|
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
|
||||||
x_out = self.final_lr_conv(x_out)
|
x_out = self.final_lr_conv(x_out)
|
||||||
|
final_feature_map = x_out
|
||||||
|
if only_return_final_feature_map:
|
||||||
|
return final_feature_map
|
||||||
x_out = checkpoint(self.upsample, x_out)
|
x_out = checkpoint(self.upsample, x_out)
|
||||||
x_out = checkpoint(self.final_hr_conv1, x_out)
|
x_out = checkpoint(self.final_hr_conv1, x_out)
|
||||||
x_out = self.final_hr_conv2(x_out)
|
x_out = self.final_hr_conv2(x_out)
|
||||||
|
@ -565,7 +569,7 @@ class Spsr7(nn.Module):
|
||||||
self.attentions = [a1, a2, a3, a4]
|
self.attentions = [a1, a2, a3, a4]
|
||||||
self.grad_fea_std = grad_fea_std.detach().cpu()
|
self.grad_fea_std = grad_fea_std.detach().cpu()
|
||||||
self.fea_grad_std = fea_grad_std.detach().cpu()
|
self.fea_grad_std = fea_grad_std.detach().cpu()
|
||||||
return x_grad_out, x_out, s1out, s2out
|
return x_grad_out, x_out, final_feature_map
|
||||||
|
|
||||||
def set_temperature(self, temp):
|
def set_temperature(self, temp):
|
||||||
[sw.set_temperature(temp) for sw in self.switches]
|
[sw.set_temperature(temp) for sw in self.switches]
|
||||||
|
|
|
@ -28,11 +28,16 @@ def create_loss(opt_loss, env):
|
||||||
|
|
||||||
|
|
||||||
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
|
# Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
|
||||||
def extract_params_from_state(params, state):
|
def extract_params_from_state(params, state, root=True):
|
||||||
if isinstance(params, list) or isinstance(params, tuple):
|
if isinstance(params, list) or isinstance(params, tuple):
|
||||||
p = [state[r] for r in params]
|
p = [extract_params_from_state(r, state, False) for r in params]
|
||||||
|
elif isinstance(params, str):
|
||||||
|
p = state[params]
|
||||||
else:
|
else:
|
||||||
p = [state[params]]
|
p = params
|
||||||
|
# The root return must always be a list.
|
||||||
|
if root and not isinstance(p, list):
|
||||||
|
p = [p]
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@ -241,7 +246,10 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
|
||||||
# Undo alteration on HR image
|
# Undo alteration on HR image
|
||||||
upsampled_altered = undo_fn(upsampled_altered)
|
upsampled_altered = undo_fn(upsampled_altered)
|
||||||
|
|
||||||
return self.criterion(state[self.opt['real']], upsampled_altered)
|
if self.opt['criterion'] == 'cosine':
|
||||||
|
return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
|
||||||
|
else:
|
||||||
|
return self.criterion(state[self.opt['real']], upsampled_altered)
|
||||||
|
|
||||||
|
|
||||||
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
# Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
|
||||||
|
@ -280,11 +288,17 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
||||||
trans_output = net(*input)
|
trans_output = net(*input)
|
||||||
else:
|
else:
|
||||||
trans_output = net(*input)
|
trans_output = net(*input)
|
||||||
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
if self.gen_output_to_use:
|
||||||
|
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
||||||
|
else:
|
||||||
|
fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
|
||||||
|
|
||||||
# The "real" input is assumed to always come from the top left tile.
|
# The "real" input is assumed to always come from the top left tile.
|
||||||
gen_output = state[self.opt['real']]
|
gen_output = state[self.opt['real']]
|
||||||
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
|
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
|
||||||
|
|
||||||
return self.criterion(fake_shared_output, real_shared_output)
|
if self.opt['criterion'] == 'cosine':
|
||||||
|
return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
|
||||||
|
else:
|
||||||
|
return self.criterion(fake_shared_output, real_shared_output)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user