Also debug distribution of switch

This commit is contained in:
James Betker 2021-06-07 15:36:07 -06:00
parent 44b09e5f20
commit 438217094c
2 changed files with 7 additions and 4 deletions

View File

@ -208,6 +208,10 @@ class ResNet(nn.Module):
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def get_debug_values(self, step, __):
logs = {'histogram_switch_usage': self.latest_masks}
return logs
def forward(self, x, coarse_label, return_selector=False):
output = self.conv1(x)
output = self.conv2_x(output)
@ -221,6 +225,7 @@ class ResNet(nn.Module):
query = self.selector(output).unsqueeze(2)
selector = self.selector_gate(query * keys).squeeze(-1)
selector = self.gate(selector)
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,8) == selector).float().argmax(dim=1)
values = self.final_linear(selector.unsqueeze(-1) * keys)
if return_selector:
@ -262,6 +267,7 @@ if __name__ == '__main__':
model = ResNet(BasicBlock, [2,2,2,2])
for j in range(10):
v = model(torch.randn(256,3,32,32), None)
print(model.get_debug_values(0, None))
print(v.shape)
l = nn.MSELoss()(v, torch.randn_like(v))
l.backward()

View File

@ -4,14 +4,11 @@ import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
from models.segformer.backbone import backbone50
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
from trainer.networks import register_model
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
def gather_2d(input, index):
b, c, h, w = input.shape
nodim = input.view(b, c, h * w)