This commit is contained in:
Mitchell Wortsman 2023-03-29 19:04:53 +00:00
parent 5f3d9ada8d
commit b373034e31

View File

@ -97,7 +97,7 @@ class Attention(torch.nn.Module):
def forward(self, x, attn_mask=None):
    """Pre-norm self-attention forward pass.

    Args:
        x: input tensor; last dim must match the LayerNorm / projection width.
           (Presumably (batch, seq, dim) — TODO confirm against caller.)
        attn_mask: optional attention mask forwarded to
            ``torch.nn.functional.scaled_dot_product_attention``.

    Returns:
        Tensor of the same shape as ``x`` after attention and output projection.
    """
    # LayerNorm first (pre-norm), then one fused linear producing q, k, v.
    q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1)
    # NOTE: the removed variant wrapped the SDPA *result* in torch.compile;
    # torch.compile expects a callable, not a tensor, and the value was
    # immediately overwritten — call SDPA directly instead.
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
    x = self.out_proj(x)
    return x