Register branched resnet properly

This commit is contained in:
James Betker 2021-06-05 14:19:03 -06:00
parent fb405d9ef1
commit 5f0cc65f3b

View File

@ -143,7 +143,7 @@ class ResNet(nn.Module):
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
@register_model
def register_cifar_resnet18(opt_net, opt):
def register_cifar_resnet18_branched(opt_net, opt):
""" return a ResNet 18 object
"""
return ResNet(BasicBlock, [2, 2, 2, 2])