un-fuse bias

justheuristic 2022-09-17 23:44:28 +03:00
parent 7facedda38
commit d9ca0ed905
2 changed files with 6 additions and 4 deletions


@@ -234,8 +234,6 @@ class MatMul8bitLt(torch.autograd.Function):
         if A_dtype != torch.float16:
             warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
             A = A.to(torch.float16)
-        if bias is not None:
-            bias = bias.to(torch.float16)
 
         # 1. Quantize A
         if len(A.shape) == 3:
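For illustration only (not part of the commit): the removed lines forced a float32 bias down to float16 before the matmul, which rounds away low-order bits; dropping the cast lets the bias keep its original dtype and be applied after dequantization, as the next hunk shows.

import torch

# Rounding error introduced by the cast this hunk removes (illustrative sketch).
bias_fp32 = torch.randn(4096) * 1e-4
bias_fp16 = bias_fp32.to(torch.float16)
print((bias_fp32 - bias_fp16.float()).abs().max())  # non-zero: float16 loses precision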
@@ -315,7 +313,11 @@ class MatMul8bitLt(torch.autograd.Function):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+        fused_bias = bias if bias is not None and bias.dtype == torch.float16 else None
+        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
+        if fused_bias is None and bias is not None:
+            output.add_(bias.to(output.dtype))
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
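A minimal stand-alone sketch of the logic added above (apply_bias is an illustrative helper, not the bitsandbytes API): a bias that is already float16 can stay fused inside F.mm_dequant, while any other bias is cast to the output dtype and added afterwards.

from typing import Optional

import torch

def apply_bias(output: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Mirrors the un-fused path: assume a float16 bias was already fused into
    # dequantization; anything else is added here in the output's dtype.
    fused = bias is not None and bias.dtype == torch.float16
    if not fused and bias is not None:
        output = output.add_(bias.to(output.dtype))
    return output

# A float32 bias is no longer rounded to float16 up front; it is added after dequantization.
out = torch.zeros(2, 3, dtype=torch.float16)
print(apply_bias(out, torch.full((3,), 0.125, dtype=torch.float32)))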


@@ -427,4 +427,4 @@ def test_matmullt(
         )
     if req_grad[2]:
-        torch.testing.assert_allclose(gradBias1, gradBias2, atol=0.18, rtol=0.3)
+        torch.testing.assert_allclose(gradBias1, gradBias2)
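For context (illustrative, not part of the test file): dropping the explicit atol=0.18 / rtol=0.3 makes assert_allclose fall back to its dtype-based default tolerances, a much stricter check on the bias gradient.

import torch

a = torch.tensor([1.0, 2.0], dtype=torch.float16)
b = a + 0.1
torch.testing.assert_allclose(a, b, atol=0.18, rtol=0.3)  # passes under the old, loose tolerances
# torch.testing.assert_allclose(a, b)  # would fail at the stricter defaults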