forked from mrq/DL-Art-School
Another fix to anorm
This commit is contained in:
parent 9a6991e461
commit 061dbcd458
@@ -193,7 +193,7 @@ def resnet152():
 
 if __name__ == '__main__':
     model = ResNet(BasicBlock, [2,2,2,2])
-    v = model(torch.randn(2,3,32,32), torch.LongTensor([4,7]))
+    v = model(torch.randn(256,3,32,32), None)
     print(v.shape)
     l = nn.MSELoss()(v, torch.randn_like(v))
     l.backward()
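Context for the test change above: the removed line drove the model with explicit per-sample conditioning (two images, labels 4 and 7), while the added line runs a larger unconditional batch, presumably to exercise the normalization layer with more samples per step. A minimal sketch of the two call modes, assuming the model's forward takes (x, labels) with labels optional (an inference from the diff, not confirmed by the source):

    import torch

    # Hypothetical usage sketch; ResNet and BasicBlock come from the
    # patched file and are assumed to be in scope.
    model = ResNet(BasicBlock, [2, 2, 2, 2])

    # Conditioned mode (the removed line): one integer label per batch element.
    v = model(torch.randn(2, 3, 32, 32), torch.LongTensor([4, 7]))

    # Unconditional mode (the added line): labels=None, larger batch.
    v = model(torch.randn(256, 3, 32, 32), None)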
@@ -118,7 +118,7 @@ class SwitchNorm(nn.Module):
         self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))
 
     def add_norm_to_buffer(self, x):
-        flatten_dims = [0] + [k+2 for k in range(len(x)-2)]
+        flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
         flat = x.sum(dim=flatten_dims)
         norm = flat / torch.mean(flat)
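The one-line SwitchNorm fix corrects a common PyTorch pitfall: len(tensor) returns the size of the first dimension, not the number of dimensions, so range(len(x)-2) produced a wildly oversized flatten_dims list for any realistic batch size. len(x.shape) gives the tensor's rank. A self-contained demonstration:

    import torch

    x = torch.randn(256, 3, 32, 32)  # (batch, channels, height, width)

    print(len(x))        # 256: size of dim 0, not the rank
    print(len(x.shape))  # 4: the actual number of dimensions

    # Fixed computation: reduce over the batch and all spatial dims,
    # keeping only the channel dim (dim 1).
    flatten_dims = [0] + [k + 2 for k in range(len(x.shape) - 2)]
    print(flatten_dims)              # [0, 2, 3]

    flat = x.sum(dim=flatten_dims)
    print(flat.shape)                # torch.Size([3]), one entry per channel

    # The buggy version would have computed range(256 - 2) and tried to
    # sum over dims 2..255, which do not exist for a 4-D tensor.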