forked from mrq/DL-Art-School
Also log a histogram of switch usage for debugging
This commit is contained in:
parent
44b09e5f20
commit
438217094c
|
@ -208,6 +208,10 @@ class ResNet(nn.Module):
|
|||
self.in_channels = out_channels * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
logs = {'histogram_switch_usage': self.latest_masks}
|
||||
return logs
|
||||
|
||||
def forward(self, x, coarse_label, return_selector=False):
|
||||
output = self.conv1(x)
|
||||
output = self.conv2_x(output)
|
||||
|
@ -221,6 +225,7 @@ class ResNet(nn.Module):
|
|||
query = self.selector(output).unsqueeze(2)
|
||||
selector = self.selector_gate(query * keys).squeeze(-1)
|
||||
selector = self.gate(selector)
|
||||
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,8) == selector).float().argmax(dim=1)
|
||||
values = self.final_linear(selector.unsqueeze(-1) * keys)
|
||||
|
||||
if return_selector:
|
||||
|
@ -262,6 +267,7 @@ if __name__ == '__main__':
|
|||
model = ResNet(BasicBlock, [2,2,2,2])
|
||||
for j in range(10):
|
||||
v = model(torch.randn(256,3,32,32), None)
|
||||
print(model.get_debug_values(0, None))
|
||||
print(v.shape)
|
||||
l = nn.MSELoss()(v, torch.randn_like(v))
|
||||
l.backward()
|
||||
|
|
|
@ -4,14 +4,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.segformer.backbone import backbone50
|
||||
|
||||
|
||||
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
|
||||
def gather_2d(input, index):
|
||||
b, c, h, w = input.shape
|
||||
nodim = input.view(b, c, h * w)
|
||||
|
|
Loading…
Reference in New Issue
Block a user