diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index c2c8903a..bae2ba9c 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -492,9 +492,14 @@ class RelativeQKBias(nn.Module): r = o.unsqueeze(0).repeat(max_positions,1) M = ((-(r-c).abs())+l).clamp(0,l) self.register_buffer('M', M, persistent=False) + self.initted = False def forward(self, n): - return self.emb[self.M[:n, :n]].view(1,n,n) + # Ideally, I'd return this: + # return self.emb[self.M[:n, :n]].view(1,n,n) + # However, indexing operations like this have horrible efficiency on GPUs: https://github.com/pytorch/pytorch/issues/15245 + # So, enter this horrible, equivalent mess: + return torch.gather(self.emb.unsqueeze(-1).repeat(1,n), 0, self.M[:n,:n]).view(1,n,n) class AttentionBlock(nn.Module):