From 061dbcd4586005a3b495ce420f80d81e1760fb51 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 6 Jun 2021 15:09:49 -0600 Subject: [PATCH] Another fix to anorm --- codes/models/classifiers/cifar_resnet_branched.py | 2 +- codes/models/switched_conv/switched_conv_hard_routing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index e314b22e..dd80ef44 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -193,7 +193,7 @@ def resnet152(): if __name__ == '__main__': model = ResNet(BasicBlock, [2,2,2,2]) - v = model(torch.randn(2,3,32,32), torch.LongTensor([4,7])) + v = model(torch.randn(256,3,32,32), None) print(v.shape) l = nn.MSELoss()(v, torch.randn_like(v)) l.backward() diff --git a/codes/models/switched_conv/switched_conv_hard_routing.py b/codes/models/switched_conv/switched_conv_hard_routing.py index 2b2c73b3..daad9022 100644 --- a/codes/models/switched_conv/switched_conv_hard_routing.py +++ b/codes/models/switched_conv/switched_conv_hard_routing.py @@ -118,7 +118,7 @@ class SwitchNorm(nn.Module): self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size)) def add_norm_to_buffer(self, x): - flatten_dims = [0] + [k+2 for k in range(len(x)-2)] + flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)] flat = x.sum(dim=flatten_dims) norm = flat / torch.mean(flat)