forked from mrq/DL-Art-School
#yolo
This commit is contained in:
parent
b92ff8de78
commit
02ebda42f2
|
@ -492,9 +492,14 @@ class RelativeQKBias(nn.Module):
|
|||
r = o.unsqueeze(0).repeat(max_positions,1)
|
||||
M = ((-(r-c).abs())+l).clamp(0,l)
|
||||
self.register_buffer('M', M, persistent=False)
|
||||
self.initted = False
|
||||
|
||||
def forward(self, n):
|
||||
return self.emb[self.M[:n, :n]].view(1,n,n)
|
||||
# Ideally, I'd return this:
|
||||
# return self.emb[self.M[:n, :n]].view(1,n,n)
|
||||
# However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245
|
||||
# So, enter this horrible, equivalent mess:
|
||||
return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue
Block a user