James Betker 2022-07-21 00:43:03 +07:00
parent b92ff8de78
commit 02ebda42f2
1 changed file with 6 additions and 1 deletion

@@ -492,9 +492,14 @@ class RelativeQKBias(nn.Module):
         r = o.unsqueeze(0).repeat(max_positions,1)
         M = ((-(r-c).abs())+l).clamp(0,l)
         self.register_buffer('M', M, persistent=False)
+        self.initted = False
 
     def forward(self, n):
-        return self.emb[self.M[:n, :n]].view(1,n,n)
+        # Ideally, I'd return this:
+        # return self.emb[self.M[:n, :n]].view(1,n,n)
+        # However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245
+        # So, enter this horrible, equivalent mess:
+        return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n)
 
 
 class AttentionBlock(nn.Module):
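
The gather call added above is just a reshuffled form of the fancy-indexing expression it replaces. A minimal, standalone sketch (not part of the commit; the shapes are assumptions inferred from the diff, treating emb as a 1-D bias table of length l+1 and M as an (n, n) integer index matrix with values in [0, l]) that checks the two forms agree:

import torch

l, n = 32, 8
emb = torch.randn(l + 1)             # assumed: per-distance bias values, shape (l+1,)
M = torch.randint(0, l + 1, (n, n))  # assumed: relative-distance index matrix, values in [0, l]

# Readable version: advanced indexing (hits a slow kernel on GPU, see pytorch#15245).
indexed = emb[M].view(1, n, n)

# Workaround: broadcast emb to shape (l+1, n), then gather along dim 0, so that
# out[i, j] = emb_rep[M[i, j], j] = emb[M[i, j]].
gathered = torch.gather(emb.unsqueeze(-1).repeat(1, n), 0, M).view(1, n, n)

assert torch.equal(indexed, gathered)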