diff --git a/tests/triton_tests/attn_decomp.py b/tests/triton_tests/attn_decomp.py
index 9e8ed28..fa86995 100644
--- a/tests/triton_tests/attn_decomp.py
+++ b/tests/triton_tests/attn_decomp.py
@@ -97,7 +97,7 @@ class Attention(torch.nn.Module):
 
     def forward(self, x, attn_mask = None):
         q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1)
-        x = torch.compile(torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask))
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
         x = self.out_proj(x)
         return x
 