diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index fdc7ddaf..2c3a6eb4 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -7,7 +7,8 @@ from collections import OrderedDict from models.archs.arch_util import ConvBnLelu, ConvGnSilu, ExpansionBlock from models.archs.RRDBNet_arch import ResidualDenseBlock_5C, RRDB from models.archs.spinenet_arch import SpineNet -from switched_conv_util import save_attention_to_image +from switched_conv_util import save_attention_to_image_rgb +import os class MultiConvBlock(nn.Module): @@ -228,7 +229,9 @@ class ConfigurableSwitchedResidualGenerator2(nn.Module): temp = 1 / temp self.set_temperature(temp) if step % 50 == 0: - [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] + output_path = os.path.join(experiments_path, "attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] def get_debug_values(self, step): temp = self.switches[0].switch.temperature @@ -335,7 +338,9 @@ class ConfigurableSwitchedResidualGenerator4(nn.Module): temp = 1 / temp self.set_temperature(temp) if step % 50 == 0: - [save_attention_to_image(experiments_path, self.attentions[i], self.transformation_counts, step, "a%i" % (i+1,), l_mult=10) for i in range(len(self.attentions))] + output_path = os.path.join(experiments_path, "attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] def get_debug_values(self, step): temp = self.switches[0].switch.temperature @@ -426,7 +431,9 @@ class ConfigurableSwitchedResidualGenerator3(nn.Module): temp = 1 / temp self.set_temperature(temp) if step % 50 == 0: - save_attention_to_image(experiments_path, self.attentions[0], self.transformation_counts, step, "a%i" % (1,), l_mult=10) + output_path = os.path.join(experiments_path, "attention_maps", "a%i") + prefix = "attention_map_%i_%%i.png" % (step,) + [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] def get_debug_values(self, step): temp = self.switch.switch.temperature