Update attention debugger outputting for SSG

This commit is contained in:
James Betker 2020-09-16 13:09:46 -06:00
parent 0b047e5f80
commit 723754c133

View File

@ -9,6 +9,7 @@ import torch.nn.functional as F
from switched_conv_util import save_attention_to_image_rgb
from switched_conv import compute_attention_specificity
import os
import torchvision
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
@ -150,15 +151,18 @@ class SSGr1(nn.Module):
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
self.upsample = nn.Sequential(*[UpconvBlock(nf, 64, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
self.final_hr_conv2 = ConvGnLelu(64, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw_grad, self.conjoin_sw]
self.attentions = None
self.lr = None
self.init_temperature = init_temperature
self.final_temperature_step = 10000
def forward(self, x, embedding):
noise_stds = []
# The attention_maps debugger outputs <x>. Save that here.
self.lr = x.detach().cpu()
x_grad = self.get_g_nopadding(x)
@ -200,9 +204,11 @@ class SSGr1(nn.Module):
(self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp)
if step % 200 == 0:
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
prefix = "attention_map_%i_%%i.png" % (step,)
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
output_path = os.path.join(experiments_path, "attention_maps")
prefix = "amap_%i_a%i_%%i.png"
[save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))
def get_debug_values(self, step):
temp = self.switches[0].switch.temperature