Another fix to anorm

James Betker 2021-06-06 15:09:49 -06:00
parent 9a6991e461
commit 061dbcd458
2 changed files with 2 additions and 2 deletions

@@ -193,7 +193,7 @@ def resnet152():
 
 if __name__ == '__main__':
     model = ResNet(BasicBlock, [2,2,2,2])
-    v = model(torch.randn(2,3,32,32), torch.LongTensor([4,7]))
+    v = model(torch.randn(256,3,32,32), None)
     print(v.shape)
     l = nn.MSELoss()(v, torch.randn_like(v))
     l.backward()
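
For context, a standalone version of the updated smoke test (a sketch, assuming ResNet and BasicBlock are the classes defined in this file, and that the second positional argument is an optional conditioning tensor which the new test leaves unset):

import torch
import torch.nn as nn

model = ResNet(BasicBlock, [2, 2, 2, 2])      # ResNet-18 style block layout
v = model(torch.randn(256, 3, 32, 32), None)  # 256 CIFAR-sized inputs, no conditioning
print(v.shape)
l = nn.MSELoss()(v, torch.randn_like(v))      # dummy regression loss
l.backward()                                  # confirms gradients flow end to end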

@@ -118,7 +118,7 @@ class SwitchNorm(nn.Module):
         self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))
 
     def add_norm_to_buffer(self, x):
-        flatten_dims = [0] + [k+2 for k in range(len(x)-2)]
+        flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
         flat = x.sum(dim=flatten_dims)
 
         norm = flat / torch.mean(flat)
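
The small-looking change is the actual bug fix: for a torch.Tensor, len(x) returns the size of dimension 0 (the batch size), not the number of dimensions, while len(x.shape) returns the rank. A minimal sketch of the difference on a typical 4-D activation:

import torch

x = torch.randn(8, 16, 32, 32)  # (batch, channels, height, width)

# Buggy: len(x) == 8 (the batch size), so the comprehension yields
# [0, 2, 3, 4, 5, 6, 7], which indexes dimensions a 4-D tensor lacks.
buggy_dims = [0] + [k + 2 for k in range(len(x) - 2)]

# Fixed: len(x.shape) == 4 (the rank), yielding [0, 2, 3]:
# sum over the batch and spatial dimensions, keep the channel dimension.
fixed_dims = [0] + [k + 2 for k in range(len(x.shape) - 2)]

print(buggy_dims)                   # [0, 2, 3, 4, 5, 6, 7]
print(fixed_dims)                   # [0, 2, 3]
print(x.sum(dim=fixed_dims).shape)  # torch.Size([16]), one value per channel

With the fix, flat has one entry per channel, which is presumably what the (accumulator_size, group_size) buffer registered above expects; the buggy version would raise an IndexError for any input with fewer dimensions than its batch size.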