diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index 03c6ac0b..42694438 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -143,7 +143,7 @@ class ResNet(nn.Module): return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1) @register_model -def register_cifar_resnet18(opt_net, opt): +def register_cifar_resnet18_branched(opt_net, opt): """ return a ResNet 18 object """ return ResNet(BasicBlock, [2, 2, 2, 2])