diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py
index a5a6c750..5eda08f8 100644
--- a/codes/models/classifiers/cifar_resnet_branched.py
+++ b/codes/models/classifiers/cifar_resnet_branched.py
@@ -112,7 +112,7 @@ class ResNetTail(nn.Module):
 
 
 class DropoutNorm(SwitchNorm):
-    def __init__(self, group_size, dropout_rate, accumulator_size=256):
+    def __init__(self, group_size, dropout_rate, accumulator_size=256, eps=1e-6):
         super().__init__(group_size, accumulator_size)
         self.accumulator_desired_size = accumulator_size
         self.group_size = group_size
@@ -120,6 +120,7 @@ class DropoutNorm(SwitchNorm):
         self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
         self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
         self.register_buffer("accumulator", torch.zeros(accumulator_size, group_size))
+        self.eps = eps
 
     def add_norm_to_buffer(self, x):
         flatten_dims = [0] + [k+2 for k in range(len(x.shape)-2)]
@@ -165,7 +166,7 @@ class DropoutNorm(SwitchNorm):
         # Ensure that there is always at least one switch left un-dropped out
         fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
         drop = drop.logical_or(fix_blank)
-        x = drop * x
+        x = drop * x + (drop.logical_not() * x * self.eps)
         return x
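The changed line replaces a hard zeroing of dropped switch branches with an eps-scaled pass-through, so dropped branches still emit (and backpropagate) a tiny signal. Below is a minimal standalone sketch of just that masking step; the helper name drop_with_eps_leak, the (batch, branches) input shape, and the random construction of the drop mask are assumptions for illustration, while the fix_blank/logical_or guard and the eps leak mirror the diff.

import torch

def drop_with_eps_leak(x: torch.Tensor, dropout_rate: float, eps: float = 1e-6) -> torch.Tensor:
    # x: (batch, branches) switch weights; True in `drop` means "keep this branch".
    # The random mask below is an assumption; only the guard and eps leak come from the diff.
    b, br = x.shape
    drop = torch.rand(b, br, device=x.device) > dropout_rate
    # Ensure that there is always at least one switch left un-dropped out.
    fix_blank = (drop.sum(dim=1, keepdim=True) == 0).repeat(1, br)
    drop = drop.logical_or(fix_blank)
    # Kept branches pass through unchanged; dropped branches are attenuated
    # by eps instead of being zeroed (the behaviour introduced in this diff).
    return drop * x + drop.logical_not() * x * eps

Example: drop_with_eps_leak(torch.rand(4, 8), dropout_rate=0.5) keeps each sample's surviving branches intact and scales the rest by 1e-6, which presumably preserves a gradient path to dropped branches rather than cutting them off entirely.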