forked from mrq/bitsandbytes-rocm
change order
This commit is contained in:
parent
0de1a4494b
commit
647c976a74
|
@@ -316,10 +316,10 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
if bias is None or bias.dtype == torch.float16:
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||
delayed_bias = None
|
||||
output = output.to(A_dtype)
|
||||
else: # apply bias separately
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
||||
delayed_bias = bias
|
||||
output = output.to(A_dtype).add_(bias)
|
||||
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
if coo_tensorA is not None and subA is not None:
|
||||
|
@@ -340,9 +340,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
output = output.to(A_dtype)
|
||||
if delayed_bias is not None:
|
||||
output.add_(delayed_bias)
|
||||
|
||||
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||
return clone_func(output.view(output_shape))
|
||||
|
|
Loading…
Reference in New Issue
Block a user