Add regular attention to cifar_resnet

James Betker 2021-06-05 21:34:07 -06:00
parent 16cd92acd5
commit f5e75602b9


@@ -121,7 +121,9 @@ class ResNet(nn.Module):
         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
-        self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)])
+        self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)])
+        self.selector = ResNetTail(block, num_block, num_tails)
+        self.final_linear = nn.Linear(256, 100)

     def _make_layer(self, block, out_channels, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)
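Read in shape terms (bs below stands for the batch size and is only illustrative): each ResNetTail now produces a 256-dim key vector per image instead of per-class logits, the new selector tail produces one logit per tail that acts as the attention query in the forward pass below, and final_linear maps the attended 256-dim features to CIFAR-100's 100 classes.

# Shapes implied by the constructor changes (bs = batch size, illustrative):
#   t(output) for each tail t  -> [bs, 256]         key vector
#   self.selector(output)      -> [bs, num_tails]   query logit per tail
#   self.final_linear          -> Linear(256, 100)  class logits per attended feature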
@@ -135,12 +137,20 @@ class ResNet(nn.Module):
         output = self.conv1(x)
         output = self.conv2_x(output)
         output = self.conv3_x(output)
-        bs = output.shape[0]
-        tailouts = []
+        keys = []
         for t in self.tails:
-            tailouts.append(t(output))
-        tailouts = torch.stack(tailouts, dim=0)
-        return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)
+            keys.append(t(output))
+        keys = torch.stack(keys, dim=1)
+        query = self.selector(output).unsqueeze(2)
+        attn = torch.nn.functional.softmax(query * keys, dim=1)
+        values = self.final_linear(attn * keys)
+        return values.sum(dim=1)
+        #bs = output.shape[0]
+        #return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)

 @register_model
 def register_cifar_resnet18_branched(opt_net, opt):
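For reference, a self-contained sketch of the attention wiring the new forward pass implements. The ResNetTail blocks and the convolutional trunk are stubbed out with pooling plus plain linear layers, and the TailAttentionHead name, the 128-channel trunk width, and the default of 20 tails are illustrative assumptions; the 256-dim keys, the per-tail selector query, the softmax across tails, and the final 100-way linear follow the diff above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TailAttentionHead(nn.Module):
    """Minimal sketch of attention over per-tail keys, mirroring the diff's wiring."""
    def __init__(self, trunk_channels=128, num_tails=20, key_dim=256, num_classes=100):
        super().__init__()
        # Stand-ins for ResNetTail: project pooled trunk features to a key per tail.
        self.tails = nn.ModuleList(
            [nn.Linear(trunk_channels, key_dim) for _ in range(num_tails)])
        self.selector = nn.Linear(trunk_channels, num_tails)
        self.final_linear = nn.Linear(key_dim, num_classes)

    def forward(self, trunk_features):
        # trunk_features: [bs, trunk_channels, h, w] feature map (e.g. from conv3_x).
        pooled = trunk_features.mean(dim=(2, 3))                      # [bs, trunk_channels]
        keys = torch.stack([t(pooled) for t in self.tails], dim=1)    # [bs, num_tails, key_dim]
        query = self.selector(pooled).unsqueeze(2)                    # [bs, num_tails, 1]
        # Broadcast the per-tail query over the key channels, softmax across tails.
        attn = F.softmax(query * keys, dim=1)                         # [bs, num_tails, key_dim]
        values = self.final_linear(attn * keys)                       # [bs, num_tails, num_classes]
        return values.sum(dim=1)                                      # [bs, num_classes]

if __name__ == "__main__":
    head = TailAttentionHead()
    logits = head(torch.randn(4, 128, 8, 8))
    print(logits.shape)  # torch.Size([4, 100])

Note that the softmax runs over dim=1 of a [bs, num_tails, 256] tensor, so each of the 256 key channels independently picks its own mixture over the tails, rather than each tail receiving a single scalar attention weight.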