Update attention debugger outputting for SSG
This commit is contained in:
parent
0b047e5f80
commit
723754c133
|
@ -9,6 +9,7 @@ import torch.nn.functional as F
|
|||
from switched_conv_util import save_attention_to_image_rgb
|
||||
from switched_conv import compute_attention_specificity
|
||||
import os
|
||||
import torchvision
|
||||
|
||||
|
||||
# VGG-style layer with Conv(stride2)->BN->Activation->Conv->BN->Activation
|
||||
|
@ -150,15 +151,18 @@ class SSGr1(nn.Module):
|
|||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False, feed_transforms_into_multiplexer=True)
|
||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=True, bias=True)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf // 2, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
|
||||
self.final_hr_conv2 = ConvGnLelu(nf // 2, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, 64, block=ConvGnLelu, norm=False, activation=True, bias=True) for _ in range(n_upscale)])
|
||||
self.final_hr_conv2 = ConvGnLelu(64, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
||||
self.switches = [self.sw1, self.sw_grad, self.conjoin_sw]
|
||||
self.attentions = None
|
||||
self.lr = None
|
||||
self.init_temperature = init_temperature
|
||||
self.final_temperature_step = 10000
|
||||
|
||||
def forward(self, x, embedding):
|
||||
noise_stds = []
|
||||
# The attention_maps debugger outputs <x>. Save that here.
|
||||
self.lr = x.detach().cpu()
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
|
||||
|
@ -200,9 +204,11 @@ class SSGr1(nn.Module):
|
|||
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||
self.set_temperature(temp)
|
||||
if step % 200 == 0:
|
||||
output_path = os.path.join(experiments_path, "attention_maps", "a%i")
|
||||
prefix = "attention_map_%i_%%i.png" % (step,)
|
||||
[save_attention_to_image_rgb(output_path % (i,), self.attentions[i], self.transformation_counts, prefix, step) for i in range(len(self.attentions))]
|
||||
output_path = os.path.join(experiments_path, "attention_maps")
|
||||
prefix = "amap_%i_a%i_%%i.png"
|
||||
[save_attention_to_image_rgb(output_path, self.attentions[i], self.transformation_counts, prefix % (step, i), step, output_mag=False) for i in range(len(self.attentions))]
|
||||
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", "amap_%i_base_image.png" % (step,)))
|
||||
|
||||
|
||||
def get_debug_values(self, step):
|
||||
temp = self.switches[0].switch.temperature
|
||||
|
|
Loading…
Reference in New Issue
Block a user