diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index f2fdb7d..f8403cf 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -323,7 +323,7 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: - A = A.view(-1, A.shape[-1]).contiguous() + A = A.reshape(-1, A.shape[-1]) CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: