diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c3b0ac6..a115437 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -50,8 +50,9 @@ class GlobalOutlierPooler:
 
 class MatMul8bit(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
-
+    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
+        if precision is None:
+            precision = [8, 8, 8]
         if precision[0] != 8:
             with torch.no_grad():
                 output = torch.matmul(A, B)
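
Aside (not part of the patch itself): the change swaps the mutable list default for None plus an in-body check. Python evaluates default values once, at function definition time, so a list default is a single object shared across every call; the None-sentinel pattern gives each call a fresh list. A minimal sketch of the pitfall and the fix, using hypothetical function names:

# Hypothetical illustration of the mutable-default pitfall the patch avoids.
def append_item(item, bucket=[]):      # anti-pattern: one list shared by all calls
    bucket.append(item)
    return bucket

print(append_item(1))  # [1]
print(append_item(2))  # [1, 2] -- the same list object was reused

def append_item_fixed(item, bucket=None):
    if bucket is None:                 # sentinel check, as in the patched forward()
        bucket = []
    bucket.append(item)
    return bucket

print(append_item_fixed(1))  # [1]
print(append_item_fixed(2))  # [2] -- a fresh list per call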