forked from mrq/bitsandbytes-rocm
change order
commit 0de1a4494b
parent e9b87112ee
@@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0)
 
         # Cast grad_output to fp16
         grad_output_dtype = grad_output.dtype
@@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function):
                 -1, grad_output.shape[-1]
             ).contiguous()
 
-        grad_A = grad_B = grad_bias = None
-
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
@@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function):
         else:
             raise Exception('State must contain either CBt or CB matrix for backward')
 
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
-
         # Cast grad_A back to grad_output_dtype
         grad_output = grad_output.to(grad_output_dtype)
 
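The reorder computes grad_bias while grad_output still has its incoming dtype, before the tensor is cast to fp16 for the int8 backward, and moves the grad_A = grad_B = grad_bias = None initialization up with it. A minimal standalone sketch of the new ordering, not the repository's full backward; the cast target torch.float16 is an assumption based on the "# Cast grad_output to fp16" comment, and backward_prefix is a hypothetical helper name:

    import torch

    def backward_prefix(grad_output: torch.Tensor, req_gradBias: bool):
        # initialize all gradients before any dtype changes (hypothetical helper)
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0)

        # cast grad_output to fp16 for the quantized backward path,
        # remembering the original dtype so results can be cast back later
        grad_output_dtype = grad_output.dtype
        grad_output = grad_output.to(torch.float16)

        return grad_A, grad_B, grad_bias, grad_output, grad_output_dtype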