change order

This commit is contained in:
justheuristic 2022-09-17 23:53:49 +03:00
parent e9b87112ee
commit 0de1a4494b

View File

@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0)
# Cast grad_output to fp16
grad_output_dtype = grad_output.dtype
@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function):
-1, grad_output.shape[-1]
).contiguous()
grad_A = grad_B = grad_bias = None
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function):
else:
raise Exception('State must contain either CBt or CB matrix for backward')
if req_gradBias:
grad_bias = grad_output.sum(0)
# Cast grad_A back to grad_output_dtype
grad_output = grad_output.to(grad_output_dtype)