diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index 89f76e6b..3fd3487d 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from switched_conv_util import save_attention_to_image_rgb from switched_conv import compute_attention_specificity import os +import torchvision # VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation @@ -150,15 +151,18 @@ class SSGr1(nn.Module): transform_count=self.transformation_counts, init_temp=init_temperature, add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True) self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True) - self.upsample = nn.Sequential(*[UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) - self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False) + self.upsample = nn.Sequential(*[UpconvBlock(nf, 64, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)]) + self.final_hr_conv2 = ConvGnLelu(64, out_nc, kernel_size=3, norm=False, activation=False, bias=False) self.switches = [self.sw1, self.sw_grad, self.conjoin_sw] self.attentions = None + self.lr = None self.init_temperature = init_temperature self.final_temperature_step = 10000 def forward(self, x, embedding): noise_stds = [] + # The attention_maps debugger outputs . Save that here. + self.lr = x.detach().cpu() x_grad = self.get_g_nopadding(x) @@ -200,9 +204,11 @@ class SSGr1(nn.Module): (self.final_temperature_step - step) / self.final_temperature_step) self.set_temperature(temp) if step % 200 == 0: - output_path = os.path.join(experiments_path, "attention_maps", "a%i") - prefix = "attention_map_%i_%%i.png" % (step,) - [save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))] + output_path = os.path.join(experiments_path, "attention_maps") + prefix = "amap_%i_a%i_%%i.png" + [save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))] + torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,))) + def get_debug_values(self, step): temp = self.switches[0].switch.temperature