Enable lambda visualization

parent 10ec6bda1d
commit ae4ff4a1e7
@@ -63,6 +63,7 @@ class SwitchedConv(nn.Module):
         if selector is None: # A coupler can convert from any input to a selector, so 'None' is allowed.
             selector = inp
         selector = F.softmax(self.coupler(selector), dim=1)
+        self.last_select = selector.detach().clone()
         out_shape = [s // self.stride for s in inp.shape[2:]]
         if selector.shape[2] != out_shape[0] or selector.shape[3] != out_shape[1]:
             selector = F.interpolate(selector, size=out_shape, mode="nearest")
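The only functional change in this hunk is the new self.last_select line: it stashes a detached copy of the per-pixel softmax selector so it can be rendered later without keeping the autograd graph alive. A minimal, self-contained sketch of that pattern (TinySwitchedConv and its defaults are hypothetical stand-ins, not the repo's API):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySwitchedConv(nn.Module):
    """Toy stand-in for SwitchedConv's selector path (names hypothetical)."""
    def __init__(self, in_ch=3, breadth=4):
        super().__init__()
        # The coupler maps any input to per-pixel selector logits.
        self.coupler = nn.Conv2d(in_ch, breadth, 1)

    def forward(self, inp, selector=None):
        if selector is None:  # mirror the "'None' is allowed" behavior above
            selector = inp
        selector = F.softmax(self.coupler(selector), dim=1)
        # Detach + clone, exactly like the added line: the stored copy is
        # safe to read after the step without holding the autograd graph.
        self.last_select = selector.detach().clone()
        return selector

m = TinySwitchedConv()
_ = m(torch.randn(2, 3, 16, 16))
print(m.last_select.shape)  # torch.Size([2, 4, 16, 16]); sums to 1 over dim=1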
@@ -1,4 +1,7 @@
+import os
+
 import torch
+import torchvision
 from torch import nn
 from torch.nn import functional as F

@@ -172,6 +175,7 @@ class VQVAE(nn.Module):
     ):
         super().__init__()

+        self.breadth = breadth
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, breadth=breadth)
         self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, breadth=breadth)
         self.quantize_conv_t = nn.Conv2d(channel, codebook_dim, 1)
@@ -200,6 +204,20 @@ class VQVAE(nn.Module):
         return dec, diff

+    def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
+        from matplotlib import cm
+        magnitude, indices = torch.topk(attention_out, 3, dim=1)
+        indices = indices.cpu()
+        colormap = cm.get_cmap(cmap_discrete_name, attention_size)
+        img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy()))  # TODO: use other k's
+        img = img.permute((0, 3, 1, 2))
+        torchvision.utils.save_image(img, output_file)
+
+    def visual_dbg(self, step, path):
+        convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
+        for i, c in enumerate(convs):
+            self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)
+
     def encode(self, input):
         enc_b = checkpoint(self.enc_b, input)
         enc_t = checkpoint(self.enc_t, enc_b)
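save_attention_to_image_rgb leans on matplotlib's discrete colormaps: cm.get_cmap(name, N) returns an N-color map, and calling it with an integer array does a direct lookup into those N colors, yielding float RGBA values. A standalone illustration of that trick (every name below is local to the sketch):

import torch
from matplotlib import cm

breadth = 8
# Pretend top-1 selector indices for a 1x16x16 map, values in [0, breadth).
indices = torch.randint(0, breadth, (1, 16, 16))
colormap = cm.get_cmap('viridis', breadth)     # breadth discrete colors
img = torch.tensor(colormap(indices.numpy()))  # (1, 16, 16, 4) float RGBA
img = img.permute((0, 3, 1, 2))                # NHWC -> NCHW for save_image
print(img.shape)                               # torch.Size([1, 4, 16, 16])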
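For context, one way the new hook might be driven from training code (this wrapper is an assumption, not part of the commit; only visual_dbg and the "<step>_selector_<k>.png" naming come from the diff above):

import os
import torch

def dump_selectors(model, images, step, out_dir):
    """Forward once, then write each switched conv's selector map to disk.

    `model` is assumed to be the VQVAE above and `images` a (B, C, H, W)
    batch; the forward pass populates every conv's `last_select` buffer.
    """
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        model(images)                  # populates last_select on each conv
    model.visual_dbg(step, out_dir)    # writes "<step>_selector_<k>.png", k=1..4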