Also log the distribution of switch usage for debugging
This commit is contained in:
parent
44b09e5f20
commit
438217094c
|
@ -208,6 +208,10 @@ class ResNet(nn.Module):
|
||||||
self.in_channels = out_channels * block.expansion
|
self.in_channels = out_channels * block.expansion
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
# Debug hook: returns a dict of values to log, keyed by log name.
# Neither `step` nor `__` is used by this method.
# NOTE(review): reads self.latest_masks, which forward() assigns; presumably forward()
# must have run at least once before this is called -- confirm whether __init__ sets a default.
def get_debug_values(self, step, __):
|
||||||
|
# latest_masks holds, per row of the most recent selector, the index along dim 1 of its maximum.
logs = {'histogram_switch_usage': self.latest_masks}
|
||||||
|
return logs
|
||||||
|
|
||||||
def forward(self, x, coarse_label, return_selector=False):
|
def forward(self, x, coarse_label, return_selector=False):
|
||||||
output = self.conv1(x)
|
output = self.conv1(x)
|
||||||
output = self.conv2_x(output)
|
output = self.conv2_x(output)
|
||||||
|
@ -221,6 +225,7 @@ class ResNet(nn.Module):
|
||||||
query = self.selector(output).unsqueeze(2)
|
query = self.selector(output).unsqueeze(2)
|
||||||
selector = self.selector_gate(query * keys).squeeze(-1)
|
selector = self.selector_gate(query * keys).squeeze(-1)
|
||||||
selector = self.gate(selector)
|
selector = self.gate(selector)
|
||||||
|
self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,8) == selector).float().argmax(dim=1)
|
||||||
values = self.final_linear(selector.unsqueeze(-1) * keys)
|
values = self.final_linear(selector.unsqueeze(-1) * keys)
|
||||||
|
|
||||||
if return_selector:
|
if return_selector:
|
||||||
|
@ -262,6 +267,7 @@ if __name__ == '__main__':
|
||||||
model = ResNet(BasicBlock, [2,2,2,2])
|
model = ResNet(BasicBlock, [2,2,2,2])
|
||||||
for j in range(10):
|
for j in range(10):
|
||||||
v = model(torch.randn(256,3,32,32), None)
|
v = model(torch.randn(256,3,32,32), None)
|
||||||
|
print(model.get_debug_values(0, None))
|
||||||
print(v.shape)
|
print(v.shape)
|
||||||
l = nn.MSELoss()(v, torch.randn_like(v))
|
l = nn.MSELoss()(v, torch.randn_like(v))
|
||||||
l.backward()
|
l.backward()
|
||||||
|
|
|
@ -4,14 +4,11 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from models.segformer.backbone import backbone50
|
from models.segformer.backbone import backbone50
|
||||||
|
|
||||||
|
|
||||||
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
|
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
|
|
||||||
|
|
||||||
|
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
|
||||||
def gather_2d(input, index):
|
def gather_2d(input, index):
|
||||||
b, c, h, w = input.shape
|
b, c, h, w = input.shape
|
||||||
nodim = input.view(b, c, h * w)
|
nodim = input.view(b, c, h * w)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user