diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index f66abb7b..4d1e7894 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -208,6 +208,10 @@ class ResNet(nn.Module): self.in_channels = out_channels * block.expansion return nn.Sequential(*layers) + def get_debug_values(self, step, __): + logs = {'histogram_switch_usage': self.latest_masks} + return logs + def forward(self, x, coarse_label, return_selector=False): output = self.conv1(x) output = self.conv2_x(output) @@ -221,6 +225,7 @@ class ResNet(nn.Module): query = self.selector(output).unsqueeze(2) selector = self.selector_gate(query * keys).squeeze(-1) selector = self.gate(selector) + self.latest_masks = (selector.max(dim=1, keepdim=True)[0].repeat(1,8) == selector).float().argmax(dim=1) values = self.final_linear(selector.unsqueeze(-1) * keys) if return_selector: @@ -262,6 +267,7 @@ if __name__ == '__main__': model = ResNet(BasicBlock, [2,2,2,2]) for j in range(10): v = model(torch.randn(256,3,32,32), None) + print(model.get_debug_values(0, None)) print(v.shape) l = nn.MSELoss()(v, torch.randn_like(v)) l.backward() diff --git a/codes/models/segformer/segformer.py b/codes/models/segformer/segformer.py index 04daa225..f9555405 100644 --- a/codes/models/segformer/segformer.py +++ b/codes/models/segformer/segformer.py @@ -4,14 +4,11 @@ import torch import torch.nn as nn import torchvision from tqdm import tqdm - from models.segformer.backbone import backbone50 - - -# torch.gather() which operates as it always fucking should have: pulling indexes from the input. from trainer.networks import register_model +# torch.gather() which operates as it always fucking should have: pulling indexes from the input. def gather_2d(input, index): b, c, h, w = input.shape nodim = input.view(b, c, h * w)