un-fuse bias

justheuristic 2022-09-17 23:46:37 +03:00
parent d9ca0ed905
commit 56a074f6dc


@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-        # we apply the fused bias here
-        fused_bias = bias if bias.dtype == torch.float16 else None
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
-        if fused_bias is None and bias is not None:
-            output.add_(bias.to(output.dtype))
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A_dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A_dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
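
The hunk above un-fuses the bias from the dequantization call: the kernel only fuses an fp16 bias, so any other bias dtype is now added separately after the output is cast back to the input dtype. The new condition also avoids dereferencing bias.dtype when bias is None, which the old fused_bias expression would have crashed on. A minimal sketch of the new branching, assuming a hypothetical dequant stand-in for F.mm_dequant (names here are illustrative, not the library's API):

    import torch

    def dequant(out32, scale, bias=None):
        # hypothetical stand-in for F.mm_dequant: dequantize int32 matmul
        # accumulators to fp16, optionally fusing an fp16 bias add
        out = (out32.float() * scale).half()
        return out if bias is None else out + bias

    def dequant_with_bias(out32, scale, bias, A_dtype):
        if bias is None or bias.dtype == torch.float16:
            # an fp16 bias can be fused directly into the dequantization
            output = dequant(out32, scale, bias=bias)
            output = output.to(A_dtype)
        else:  # apply bias separately, after casting to the input dtype
            output = dequant(out32, scale, bias=None)
            output = output.to(A_dtype).add_(bias)
        return output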
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
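
The second hunk drops the trailing cast: both branches of the new bias logic already cast the output to A_dtype, so the blanket output = output.to(A_dtype) before the return would be redundant. A quick usage check of the sketch above, with toy tensors (again purely illustrative):

    out32 = torch.randint(-2**15, 2**15, (4, 8), dtype=torch.int32)
    bias = torch.randn(8, dtype=torch.bfloat16)  # non-fp16 bias takes the un-fused path
    result = dequant_with_bias(out32, scale=1e-4, bias=bias, A_dtype=torch.bfloat16)
    assert result.dtype == torch.bfloat16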