diff --git a/codes/models/vqvae/vqvae_3_hardswitch.py b/codes/models/vqvae/vqvae_3_hardswitch.py
index 1222a5ba..958eda98 100644
--- a/codes/models/vqvae/vqvae_3_hardswitch.py
+++ b/codes/models/vqvae/vqvae_3_hardswitch.py
@@ -1,6 +1,10 @@
+import os
+from time import time
+
 import torch
 import torchvision
 from torch import nn
+from tqdm import tqdm
 
 from models.switched_conv.switched_conv_hard_routing import SwitchedConvHardRouting, \
     convert_conv_net_state_dict_to_switched_conv
@@ -146,7 +150,29 @@ class VQVAE3HardSwitch(nn.Module):
     def visual_dbg(self, step, path):
         convs = [self.dec.blocks[-1].conv, self.dec_t.blocks[-1].conv, self.enc_b.blocks[-4], self.enc_t.blocks[-4]]
         for i, c in enumerate(convs):
-            self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, self.breadth)
+            self.save_attention_to_image_rgb(os.path.join(path, "%i_selector_%i.png" % (step, i+1)), c.last_select, 16)
+
+    def get_debug_values(self, step, __):
+        """Gather switch-usage data from selected sub-modules for logging.
+
+        Neither `step` nor the second argument is used by this method.
+
+        Returns:
+            dict mapping '<name>_histogram_switch_usage' to that module's
+            `latest_masks` attribute.
+        """
+        switched_convs = [('enc_b_blk2', self.enc_b.blocks[2]),
+                          ('enc_b_blk4', self.enc_b.blocks[4]),
+                          ('enc_t_blk2', self.enc_t.blocks[2]),
+                          ('dec_t_blk0', self.dec_t.blocks[0]),
+                          ('dec_t_blk-1', self.dec_t.blocks[-1].conv),
+                          ('dec_blk0', self.dec.blocks[0]),
+                          ('dec_blk-1', self.dec.blocks[-1].conv),
+                          ('dec_blk-3', self.dec.blocks[-3].conv)]
+        logs = {}
+        for name, swc in switched_convs:
+            logs[f'{name}_histogram_switch_usage'] = swc.latest_masks
+        return logs
 
     def encode(self, input):
         fea = self.initial_conv(input)
@@ -205,7 +231,32 @@ def register_vqvae3_hard_switch(opt_net, opt):
     return VQVAE3HardSwitch(**kw)
 
 
+def performance_test():
+    """Time 10 forward/backward/optimizer steps of VQVAE3HardSwitch on CUDA.
+
+    Prints the elapsed wall-clock time in seconds.
+    """
+    net = VQVAE3HardSwitch().to('cuda')
+    loss = nn.L1Loss()
+    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
+    # CUDA kernels are launched asynchronously; synchronize so the timer
+    # brackets only the work of the loop below.
+    torch.cuda.synchronize()
+    started = time()
+    for j in tqdm(range(10)):
+        inp = torch.rand((8, 3, 256, 256), device='cuda')
+        res = net(inp)[0]
+        l = loss(res, inp)
+        l.backward()
+        opt.step()
+        net.zero_grad()
+    # Wait for all queued GPU work to finish before reading the clock.
+    torch.cuda.synchronize()
+    print("Elapsed: ", (time()-started))
+
+
 if __name__ == '__main__':
     #v = VQVAE3HardSwitch()
     #print(v(torch.randn(1,3,128,128))[0].shape)
-    convert_weights("../../../experiments/test_vqvae3.pth")
\ No newline at end of file
+    #convert_weights("../../../experiments/test_vqvae3.pth")
+    performance_test()