Register branched resnet properly
This commit is contained in:
parent
fb405d9ef1
commit
5f0cc65f3b
|
@ -143,7 +143,7 @@ class ResNet(nn.Module):
|
||||||
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
|
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_cifar_resnet18(opt_net, opt):
|
def register_cifar_resnet18_branched(opt_net, opt):
|
||||||
""" return a ResNet 18 object
|
""" return a ResNet 18 object
|
||||||
"""
|
"""
|
||||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
return ResNet(BasicBlock, [2, 2, 2, 2])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user