Fix device error

This commit is contained in:
James Betker 2021-06-05 14:21:32 -06:00
parent 5f0cc65f3b
commit af52751d6b

View File

@ -140,7 +140,7 @@ class ResNet(nn.Module):
for t in self.tails:
tailouts.append(t(output))
tailouts = torch.stack(tailouts, dim=0)
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)
@register_model
def register_cifar_resnet18_branched(opt_net, opt):