forked from mrq/bitsandbytes-rocm

un-fuse bias

parent d9ca0ed905
commit 56a074f6dc
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here

-        fused_bias = bias if bias.dtype == torch.float16 else None
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
-        if fused_bias is None and bias is not None:
-            output.add_(bias.to(output.dtype))
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A_dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A_dtype).add_(bias)
+

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
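Review note: the old code evaluated bias.dtype before checking for None, so a call without a bias raised AttributeError; the new branch tests bias is None first and only lets F.mm_dequant fuse the bias when it is fp16, apparently the only bias dtype the fused path handles. Below is a minimal sketch of the new control flow in plain PyTorch; the helper name dequant_with_bias is made up for illustration, and an ordinary fp16 tensor stands in for the mm_dequant output.

    import torch

    def dequant_with_bias(deq_fp16: torch.Tensor, bias, A_dtype):
        # deq_fp16 stands in for the fp16 result of F.mm_dequant(..., bias=None);
        # the real fused path adds an fp16 bias inside the kernel, emulated here
        # with an ordinary add before the cast back to A's dtype.
        if bias is None or bias.dtype == torch.float16:
            output = deq_fp16 if bias is None else deq_fp16 + bias
            output = output.to(A_dtype)
        else:  # apply the bias separately, after casting to A's dtype
            output = deq_fp16.to(A_dtype).add_(bias)
        return output

    # An fp32 bias takes the un-fused path; the result comes back in A's dtype:
    out = dequant_with_bias(torch.randn(4, 8, dtype=torch.float16),
                            torch.randn(8), torch.float32)
    print(out.dtype)  # torch.float32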
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)

         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
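Review note: the unconditional cast removed here became redundant, since both branches of the new bias handling above already cast output to A_dtype immediately after dequantization.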