diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 538267b..34b27d9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0) # Cast grad_output to fp16 grad_output_dtype = grad_output.dtype @@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function): -1, grad_output.shape[-1] ).contiguous() - grad_A = grad_B = grad_bias = None - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) @@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function): else: raise Exception('State must contain either CBt or CB matrix for backward') - if req_gradBias: - grad_bias = grad_output.sum(0) - # Cast grad_A back to grad_output_dtype grad_output = grad_output.to(grad_output_dtype)