forked from mrq/bitsandbytes-rocm
un-fuse bias
parent d9ca0ed905
commit 56a074f6dc
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
 
-        fused_bias = bias if bias.dtype == torch.float16 else None
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
-        if fused_bias is None and bias is not None:
-            output.add_(bias.to(output.dtype))
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A_dtype)
+
+        else: # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A_dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
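Taken out of the diff context, the new bias handling amounts to the short sketch below. It is only an illustration of the branch logic: apply_bias is a made-up helper, and the plain fp16 addition stands in for the fused bias argument of F.mm_dequant.

import torch

def apply_bias(out_fp16: torch.Tensor, bias, out_dtype: torch.dtype) -> torch.Tensor:
    # fp16 (or absent) bias: it can ride along with dequantization,
    # modeled here as a plain add on the fp16 result
    if bias is None or bias.dtype == torch.float16:
        output = out_fp16 if bias is None else out_fp16 + bias
        return output.to(out_dtype)
    # any other bias dtype: dequantize without the bias, cast back to the
    # original dtype first, then add the bias separately in place
    else:
        return out_fp16.to(out_dtype).add_(bias)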
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.tensor_states = (None, None)
         ctx.save_for_backward(None, None)
 
-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)
 
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
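Since the first hunk now casts to A_dtype inside each bias branch, the unconditional cast before the final reshape is redundant, which is what this hunk removes.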