diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index 4a407f76..ac6db451 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -528,7 +528,7 @@ class Spsr7(nn.Module):
         self.final_temperature_step = 10000
         self.lr = None
 
-    def forward(self, x, ref, ref_center):
+    def forward(self, x, ref, ref_center, only_return_final_feature_map=False):
         # The attention_maps debugger outputs <x>. Save that here.
         self.lr = x.detach().cpu()
 
@@ -551,13 +551,17 @@ class Spsr7(nn.Module):
         x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity, att_in=(x_grad, ref_embedding))
         x_grad = self.grad_lr_conv(x_grad)
         x_grad = self.grad_lr_conv2(x_grad)
-        x_grad_out = self.upsample_grad(x_grad)
-        x_grad_out = self.grad_branch_output_conv(x_grad_out)
+        if not only_return_final_feature_map:
+            x_grad_out = self.upsample_grad(x_grad)
+            x_grad_out = self.grad_branch_output_conv(x_grad_out)
 
         x_out = x2
         x_out, fea_grad_std = self.conjoin_ref_join(x_out, x_grad)
         x_out, a4 = self.conjoin_sw(x_out, True, identity=x2, att_in=(x_out, ref_embedding))
         x_out = self.final_lr_conv(x_out)
+        final_feature_map = x_out
+        if only_return_final_feature_map:
+            return final_feature_map
         x_out = checkpoint(self.upsample, x_out)
         x_out = checkpoint(self.final_hr_conv1, x_out)
         x_out = self.final_hr_conv2(x_out)
@@ -565,7 +569,7 @@ class Spsr7(nn.Module):
         self.attentions = [a1, a2, a3, a4]
         self.grad_fea_std = grad_fea_std.detach().cpu()
         self.fea_grad_std = fea_grad_std.detach().cpu()
-        return x_grad_out, x_out, s1out, s2out
+        return x_grad_out, x_out, final_feature_map
 
     def set_temperature(self, temp):
         [sw.set_temperature(temp) for sw in self.switches]
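The practical effect of `only_return_final_feature_map` is an early exit right after `final_lr_conv`, before any of the checkpointed upsampling work. A minimal usage sketch of the new calling conventions follows; the constructor arguments and tensor shapes below are placeholders, not the real `Spsr7` signature:

```python
import torch
from models.archs.SPSR_arch import Spsr7  # repo-local import

# Placeholder construction -- consult SPSR_arch.py for the actual Spsr7 arguments.
net = Spsr7(nf=64, xforms=8, upscale=4)

lq = torch.randn(1, 3, 32, 32)          # low-res input (shape assumed)
ref = torch.randn(1, 3, 128, 128)       # reference image (shape assumed)
ref_center = torch.tensor([[64, 64]])   # reference center point (shape assumed)

# Fast path: returns only the LR-resolution feature map produced by final_lr_conv,
# skipping both the gradient-branch and main-branch upsamplers.
feat = net(lq, ref, ref_center, only_return_final_feature_map=True)

# Full path: the return arity changed from four values to three
# (s1out/s2out are gone, final_feature_map is appended).
x_grad_out, x_out, final_feature_map = net(lq, ref, ref_center)
```

Dropping `s1out`/`s2out` in favor of `final_feature_map` is a breaking change for any caller that unpacks four values, so downstream steps consuming this generator's output need a matching update.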
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 7e6f7c30..6ecb4085 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -28,11 +28,16 @@ def create_loss(opt_loss, env):
 
 
 # Converts params to a list of tensors extracted from state. Works with list/tuple params as well as scalars.
-def extract_params_from_state(params, state):
+def extract_params_from_state(params, state, root=True):
     if isinstance(params, list) or isinstance(params, tuple):
-        p = [state[r] for r in params]
+        p = [extract_params_from_state(r, state, False) for r in params]
+    elif isinstance(params, str):
+        p = state[params]
     else:
-        p = [state[params]]
+        p = params
+    # The root return must always be a list.
+    if root and not isinstance(p, list):
+        p = [p]
     return p
 
 
@@ -241,7 +246,10 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
 
         # Undo alteration on HR image
         upsampled_altered = undo_fn(upsampled_altered)
-        return self.criterion(state[self.opt['real']], upsampled_altered)
+        if self.opt['criterion'] == 'cosine':
+            return self.criterion(state[self.opt['real']], upsampled_altered, torch.ones(1, device=upsampled_altered.device))
+        else:
+            return self.criterion(state[self.opt['real']], upsampled_altered)
 
 
 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
@@ -280,11 +288,17 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         if self.detach_fake:
             with torch.no_grad():
                 trans_output = net(*input)
         else:
             trans_output = net(*input)
-        fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
+        if self.gen_output_to_use:
+            fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
+        else:
+            fake_shared_output = trans_output[:, :, hl:hh, wl:wh]
 
         # The "real" input is assumed to always come from the top left tile.
         gen_output = state[self.opt['real']]
         real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
-        return self.criterion(fake_shared_output, real_shared_output)
+        if self.opt['criterion'] == 'cosine':
+            return self.criterion(fake_shared_output, real_shared_output, torch.ones(1, device=real_shared_output.device))
+        else:
+            return self.criterion(fake_shared_output, real_shared_output)
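The `extract_params_from_state` rewrite makes the helper recursive: string entries are looked up in `state`, nested lists/tuples are resolved element by element, and non-string values (e.g. boolean flags passed straight to a network) pass through untouched, while the root call still guarantees a list so callers can splat it into `net(*p)`. A self-contained sketch mirroring the patched logic, using hypothetical state keys:

```python
# Standalone copy of the refactored helper, for illustration only.
def extract_params_from_state(params, state, root=True):
    if isinstance(params, (list, tuple)):
        p = [extract_params_from_state(r, state, False) for r in params]
    elif isinstance(params, str):
        p = state[params]       # string -> lookup in state
    else:
        p = params              # scalars (e.g. True) pass through unchanged
    # The root return must always be a list.
    if root and not isinstance(p, list):
        p = [p]
    return p

state = {'lq': 'lq_tensor', 'ref': 'ref_tensor'}  # hypothetical keys
assert extract_params_from_state('lq', state) == ['lq_tensor']
assert extract_params_from_state(['lq', ['ref', 'lq']], state) == ['lq_tensor', ['ref_tensor', 'lq_tensor']]
assert extract_params_from_state(['lq', True], state) == ['lq_tensor', True]
```

The new `'cosine'` branches in both losses exist because `torch.nn.CosineEmbeddingLoss`, unlike the L1/L2 criteria, takes a third argument: a target tensor where `1` marks a pair as "should be similar" (`-1` would mean dissimilar). The single-element `torch.ones(1, ...)` used in the patch broadcasts across the batch. A small illustration of the signature difference:

```python
import torch
import torch.nn as nn

a, b = torch.randn(8, 128), torch.randn(8, 128)

# L1 compares two tensors directly...
l1_loss = nn.L1Loss()(a, b)

# ...while CosineEmbeddingLoss needs an explicit similarity target as well.
cos_loss = nn.CosineEmbeddingLoss()(a, b, torch.ones(1))  # 1-element target broadcasts
```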