Visual dbg in vqvae3hs
This commit is contained in:
parent
f5f91850fd
commit
b0a8fa00bc
|
@ -133,6 +133,20 @@ class VQVAE3HardSwitch(nn.Module):
|
|||
|
||||
return dec, diff
|
||||
|
||||
def save_attention_to_image_rgb(self, output_file, attention_out, attention_size, cmap_discrete_name='viridis'):
|
||||
from matplotlib import cm
|
||||
magnitude, indices = torch.topk(attention_out, 3, dim=1)
|
||||
indices = indices.cpu()
|
||||
colormap = cm.get_cmap(cmap_discrete_name, attention_size)
|
||||
img = torch.tensor(colormap(indices[:, 0, :, :].detach().numpy())) # TODO: use other k's
|
||||
img = img.permute((0, 3, 1, 2))
|
||||
torchvision.utils.save_image(img, output_file)
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
|
||||
for i, c in enumerate(convs):
|
||||
self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)
|
||||
|
||||
def encode(self, input):
|
||||
fea = self.initial_conv(input)
|
||||
enc_b = checkpoint(self.enc_b, fea)
|
||||
|
|
Loading…
Reference in New Issue
Block a user