diff --git a/codes/models/switched_conv.py b/codes/models/switched_conv.py index 20c6925d..2e2677b7 100644 --- a/codes/models/switched_conv.py +++ b/codes/models/switched_conv.py @@ -63,6 +63,7 @@ class SwitchedConv(nn.Module): if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed. selector = inp selector = F.softmax(self.coupler(selector), dim=1) + self.last_select = selector.detach().clone() out_shape = [s // self.stride for s in inp.shape[2:]] if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]: selector = F.interpolate(selector, size=out_shape, mode="nearest") diff --git a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py index 9037f24f..84c5170a 100644 --- a/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py +++ b/codes/models/vqvae/vqvae_no_conv_transpose_switched_lambda.py @@ -1,4 +1,7 @@ +import os + import torch +import torchvision from torch import nn from torch.nn import functional as F @@ -172,6 +175,7 @@ class VQVAE(nn.Module): ): super().__init__() + self.breadth = breadth self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth) self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth) self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1) @@ -200,6 +204,20 @@ class VQVAE(nn.Module): return dec, diff + def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'): + from matplotlib import cm + magnitude, indices = torch.topk(attention_out, 3, dim=1) + indices = indices.cpu() + colormap = cm.get_cmap(cmap_discrete_name, attention_size) + img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's + img = img.permute((0, 3, 1, 2)) + torchvision.utils.save_image(img, output_file) + + def visual_dbg(self, step, path): + convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]] + for i, c in enumerate(convs): + self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth) + def encode(self, input): enc_b = checkpoint(self.enc_b, input) enc_t = checkpoint(self.enc_t, enc_b)