diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index 42694438..806be937 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -140,7 +140,7 @@ class ResNet(nn.Module): for t in self.tails: tailouts.append(t(output)) tailouts = torch.stack(tailouts, dim=0) - return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1) + return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1) @register_model def register_cifar_resnet18_branched(opt_net, opt):