Fix device error
This commit is contained in:
parent
5f0cc65f3b
commit
af52751d6b
|
@ -140,7 +140,7 @@ class ResNet(nn.Module):
|
|||
for t in self.tails:
|
||||
tailouts.append(t(output))
|
||||
tailouts = torch.stack(tailouts, dim=0)
|
||||
return (tailouts[coarse_label] * torch.eye(n=bs).view(bs,bs,1)).sum(dim=1)
|
||||
return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)
|
||||
|
||||
@register_model
|
||||
def register_cifar_resnet18_branched(opt_net, opt):
|
||||
|
|
Loading…
Reference in New Issue
Block a user