Fix device error
This commit is contained in:
parent
5f0cc65f3b
commit
af52751d6b
|
@ -140,7 +140,7 @@ class ResNet(nn.Module):
|
||||||
for t in self.tails:
|
for t in self.tails:
|
||||||
tailouts.append(t(output))
|
tailouts.append(t(output))
|
||||||
tailouts = torch.stack(tailouts, dim=0)
|
tailouts = torch.stack(tailouts, dim=0)
|
||||||
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
|
return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_cifar_resnet18_branched(opt_net, opt):
|
def register_cifar_resnet18_branched(opt_net, opt):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user