diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 03949de..5a83dfd 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -294,7 +294,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 (outliers * state.SCB.view(-1, 1) / 127.0)
                 .t()
                 .contiguous()
-                .to(B.dtype)
+                .to(A.dtype)
             )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0