|
|
|
@ -492,9 +492,14 @@ class RelativeQKBias(nn.Module):
|
|
|
|
|
r = o.unsqueeze(0).repeat(max_positions,1)
|
|
|
|
|
M = ((-(r-c).abs())+l).clamp(0,l)
|
|
|
|
|
self.register_buffer('M', M, persistent=False)
|
|
|
|
|
self.initted = False
|
|
|
|
|
|
|
|
|
|
def forward(self, n):
|
|
|
|
|
return self.emb[self.M[:n, :n]].view(1,n,n)
|
|
|
|
|
# Ideally, I'd return this:
|
|
|
|
|
# return self.emb[self.M[:n, :n]].view(1,n,n)
|
|
|
|
|
# However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245
|
|
|
|
|
# So, enter this horrible, equivalent mess:
|
|
|
|
|
return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionBlock(nn.Module):
|
|
|
|
|