diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index c2298c8..22fe0c8 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -322,7 +322,7 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: - A = A.view(-1, A.shape[-1]).contiguous() + A = A.reshape(-1, A.shape[-1]) CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: