diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 0361ae4a..047aec22 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -391,9 +391,10 @@ class SwitchedSpsrWithRef2(nn.Module): (self.final_temperature_step - step) / self.final_temperature_step) self.set_temperature(temp) if step % 200 == 0: - output_path = os.path.join(experiments_path, "attention_maps", "a%i") - prefix = "attention_map_%i_%%i.png" % (step,) - [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature @@ -523,9 +524,10 @@ class Spsr4(nn.Module): (self.final_temperature_step - step) / self.final_temperature_step) self.set_temperature(temp) if step % 200 == 0: - output_path = os.path.join(experiments_path, "attention_maps", "a%i") - prefix = "attention_map_%i_%%i.png" % (step,) - [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature @@ -654,9 +656,10 @@ class Spsr5(nn.Module): (self.final_temperature_step - step) / self.final_temperature_step) self.set_temperature(temp) if step % 200 == 0: - output_path = os.path.join(experiments_path, "attention_maps", "a%i") - prefix = "attention_map_%i_%%i.png" % (step,) - [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) def get_debug_values(self, step, net_name): temp = self.switches[0].switch.temperature