From f5e75602b9ab78ffb4a08464ecc002426c9a479e Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Jun 2021 21:34:07 -0600 Subject: [PATCH] Add regular attention to cifar_resnet --- .../classifiers/cifar_resnet_branched.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/codes/models/classifiers/cifar_resnet_branched.py b/codes/models/classifiers/cifar_resnet_branched.py index 806be937..293ed46f 100644 --- a/codes/models/classifiers/cifar_resnet_branched.py +++ b/codes/models/classifiers/cifar_resnet_branched.py @@ -121,7 +121,9 @@ class ResNet(nn.Module): self.conv2_x = self._make_layer(block, 64, num_block[0], 1) self.conv3_x = self._make_layer(block, 128, num_block[1], 2) - self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)]) + self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)]) + self.selector = ResNetTail(block, num_block, num_tails) + self.final_linear = nn.Linear(256, 100) def _make_layer(self, block, out_channels, num_blocks, stride): strides = [stride] + [1] * (num_blocks - 1) @@ -135,12 +137,20 @@ class ResNet(nn.Module): output = self.conv1(x) output = self.conv2_x(output) output = self.conv3_x(output) - bs = output.shape[0] - tailouts = [] + + keys = [] for t in self.tails: - tailouts.append(t(output)) - tailouts = torch.stack(tailouts, dim=0) - return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1) + keys.append(t(output)) + keys = torch.stack(keys, dim=1) + + query = self.selector(output).unsqueeze(2) + attn = torch.nn.functional.softmax(query * keys, dim=1) + values = self.final_linear(attn * keys) + + return values.sum(dim=1) + + #bs = output.shape[0] + #return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1) @register_model def register_cifar_resnet18_branched(opt_net, opt):