forked from mrq/DL-Art-School
Add regular attention to cifar_resnet
parent 16cd92acd5
commit f5e75602b9
```diff
@@ -121,7 +121,9 @@ class ResNet(nn.Module):
         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
-        self.tails = nn.ModuleList([ResNetTail(block, num_block, num_classes) for _ in range(num_tails)])
+        self.tails = nn.ModuleList([ResNetTail(block, num_block, 256) for _ in range(num_tails)])
+        self.selector = ResNetTail(block, num_block, num_tails)
+        self.final_linear = nn.Linear(256, 100)

     def _make_layer(self, block, out_channels, num_blocks, stride):
         strides = [stride] + [1] * (num_blocks - 1)

```
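Reading this hunk: each `ResNetTail` now emits a 256-dim feature vector (a key) instead of class logits, a new `selector` tail emits one score per tail (the query), and a shared `nn.Linear(256, 100)` maps the attended features to the 100 CIFAR-100 fine classes. A minimal shape sketch, with `ResNetTail` stubbed as a plain `Linear` and the trunk width and tail count as assumed values rather than ones from the repository:

```python
import torch
import torch.nn as nn

feat, num_tails = 512, 20  # assumed trunk width / tail count, not from the repo
tails = nn.ModuleList([nn.Linear(feat, 256) for _ in range(num_tails)])  # keys: 256-dim each
selector = nn.Linear(feat, num_tails)  # query: one score per tail
final_linear = nn.Linear(256, 100)     # shared map to the 100 fine classes

trunk_out = torch.randn(8, feat)       # stand-in for the conv trunk output
keys = torch.stack([t(trunk_out) for t in tails], dim=1)
query = selector(trunk_out).unsqueeze(2)
print(keys.shape, query.shape)         # (8, 20, 256) and (8, 20, 1)
```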
```diff
@@ -135,12 +137,20 @@ class ResNet(nn.Module):
         output = self.conv1(x)
         output = self.conv2_x(output)
         output = self.conv3_x(output)
-        bs = output.shape[0]
-        tailouts = []
+
+        keys = []
         for t in self.tails:
-            tailouts.append(t(output))
-        tailouts = torch.stack(tailouts, dim=0)
-        return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)
+            keys.append(t(output))
+        keys = torch.stack(keys, dim=1)
+
+        query = self.selector(output).unsqueeze(2)
+        attn = torch.nn.functional.softmax(query * keys, dim=1)
+        values = self.final_linear(attn * keys)
+
+        return values.sum(dim=1)
+
+        #bs = output.shape[0]
+        #return (tailouts[coarse_label] * torch.eye(n=bs, device=x.device).view(bs,bs,1)).sum(dim=1)

 @register_model
 def register_cifar_resnet18_branched(opt_net, opt):
```
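The rewritten forward pass stacks the tail outputs into keys, uses the selector output as a query, and mixes the per-tail features with a softmax before the shared classifier. Note this is elementwise rather than dot-product attention: the query broadcasts across each key's 256 channels, and the softmax normalizes every channel independently across the tails. A runnable sketch of just that mixing step, with batch and tail sizes assumed:

```python
import torch
import torch.nn.functional as F

bs, num_tails, d, num_classes = 8, 20, 256, 100   # assumed sizes, not from the repo
keys = torch.randn(bs, num_tails, d)              # stacked tail outputs
query = torch.randn(bs, num_tails, 1)             # selector scores, unsqueezed

final_linear = torch.nn.Linear(d, num_classes)

attn = F.softmax(query * keys, dim=1)             # per-channel weights over the tails
values = final_linear(attn * keys)                # (bs, num_tails, num_classes)
logits = values.sum(dim=1)                        # (bs, num_classes): weighted mix of tail predictions
assert logits.shape == (bs, num_classes)
```

The two commented-out lines after the `return` are remnants of the old coarse-label routing and are never executed.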