Add get_debug_values for vqvae_3_hardswitch
This commit is contained in:
parent
1405ff06b8
commit
b980028ca8
|
@ -1,6 +1,10 @@
|
||||||
|
import os
|
||||||
|
from time import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
|
||||||
convert_conv_net_state_dict_to_switched_conv
|
convert_conv_net_state_dict_to_switched_conv
|
||||||
|
@ -146,7 +150,21 @@ class VQVAE3HardSwitch(nn.Module):
|
||||||
def visual_dbg(self, step, path):
|
def visual_dbg(self, step, path):
|
||||||
convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
|
convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
|
||||||
for i, c in enumerate(convs):
|
for i, c in enumerate(convs):
|
||||||
self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)
|
self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, 16)
|
||||||
|
|
||||||
|
def get_debug_values(self, step, __):
|
||||||
|
switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]),
|
||||||
|
('enc_b_blk4', self.enc_b.blocks[4]),
|
||||||
|
('enc_t_blk2', self.enc_t.blocks[2]),
|
||||||
|
('dec_t_blk0', self.dec_t.blocks[0]),
|
||||||
|
('dec_t_blk-1', self.dec_t.blocks[-1].conv),
|
||||||
|
('dec_blk0', self.dec.blocks[0]),
|
||||||
|
('dec_blk-1', self.dec.blocks[-1].conv),
|
||||||
|
('dec_blk-3', self.dec.blocks[-3].conv)]
|
||||||
|
logs = {}
|
||||||
|
for name, swc in switched_convs:
|
||||||
|
logs[f'{name}_histogram_switch_usage'] = swc.latest_masks
|
||||||
|
return logs
|
||||||
|
|
||||||
def encode(self, input):
|
def encode(self, input):
|
||||||
fea = self.initial_conv(input)
|
fea = self.initial_conv(input)
|
||||||
|
@ -205,7 +223,23 @@ def register_vqvae3_hard_switch(opt_net, opt):
|
||||||
return VQVAE3HardSwitch(**kw)
|
return VQVAE3HardSwitch(**kw)
|
||||||
|
|
||||||
|
|
||||||
|
def performance_test():
|
||||||
|
net = VQVAE3HardSwitch().to('cuda')
|
||||||
|
loss = nn.L1Loss()
|
||||||
|
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
|
||||||
|
started = time()
|
||||||
|
for j in tqdm(range(10)):
|
||||||
|
inp = torch.rand((8, 3, 256, 256), device='cuda')
|
||||||
|
res = net(inp)[0]
|
||||||
|
l = loss(res, inp)
|
||||||
|
l.backward()
|
||||||
|
opt.step()
|
||||||
|
net.zero_grad()
|
||||||
|
print("Elapsed: ", (time()-started))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#v = VQVAE3HardSwitch()
|
#v = VQVAE3HardSwitch()
|
||||||
#print(v(torch.randn(1,3,128,128))[0].shape)
|
#print(v(torch.randn(1,3,128,128))[0].shape)
|
||||||
convert_weights("../../../experiments/test_vqvae3.pth")
|
#convert_weights("../../../experiments/test_vqvae3.pth")
|
||||||
|
performance_test()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user