forked from mrq/bitsandbytes-rocm
un-fuse bias
parent 7facedda38
commit d9ca0ed905
@@ -234,8 +234,6 @@ class MatMul8bitLt(torch.autograd.Function):
         if A_dtype != torch.float16:
             warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
             A = A.to(torch.float16)
-        if bias is not None:
-            bias = bias.to(torch.float16)
 
         # 1. Quantize A
         if len(A.shape) == 3:
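The remaining dtype handling in this hunk only normalizes the input matrix; the bias is now deliberately left in whatever dtype the caller provides and is dealt with later, at the dequantization step. A minimal standalone sketch of that behavior (the function name is illustrative, not part of the diff):

import warnings
import torch

def cast_input_only(A: torch.Tensor, bias=None):
    # Only A is coerced to float16; a non-fp16 bias passes through unchanged.
    if A.dtype != torch.float16:
        warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A.dtype} to float16")
        A = A.to(torch.float16)
    return A, bias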
@@ -315,7 +313,11 @@ class MatMul8bitLt(torch.autograd.Function):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+        fused_bias = bias if bias.dtype == torch.float16 else None
+        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
+        if fused_bias is None and bias is not None:
+            output.add_(bias.to(output.dtype))
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
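The un-fused path keeps the bias inside the dequant kernel only when it is already float16; any other bias is added to the dequantized output afterwards, in the output's dtype, instead of being forced through the float16 cast removed in the previous hunk. A minimal sketch of that decision as a standalone helper (the helper name and the explicit None guard are additions for illustration; the diff line as shown assumes bias is not None when it is evaluated):

from typing import Optional

import torch

def add_bias_unfused(output: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Fuse only a float16 bias; anything else is applied after dequantization.
    fused_bias = bias if bias is not None and bias.dtype == torch.float16 else None
    if fused_bias is not None:
        # In the real code this addition happens inside F.mm_dequant(..., bias=fused_bias).
        output = output + fused_bias
    elif bias is not None:
        output.add_(bias.to(output.dtype))
    return output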
@@ -427,4 +427,4 @@ def test_matmullt(
         )
 
         if req_grad[2]:
-            torch.testing.assert_allclose(gradBias1, gradBias2, atol=0.18, rtol=0.3)
+            torch.testing.assert_allclose(gradBias1, gradBias2)
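Dropping the explicit atol=0.18 / rtol=0.3 makes the bias-gradient comparison fall back to assert_allclose's dtype-based default tolerances, which are much tighter. A standalone sketch of this kind of gradient check, using a plain linear layer as a stand-in for the 8-bit matmul the real test exercises (tensor shapes and names are illustrative, not the actual test fixtures):

import torch

torch.manual_seed(0)
A = torch.randn(4, 8)
W = torch.randn(16, 8)
bias1 = torch.randn(16, requires_grad=True)
bias2 = bias1.detach().clone().requires_grad_(True)

# Two forward paths over the same inputs; in the real test the second path
# would be the quantized matmul rather than another plain linear.
out1 = torch.nn.functional.linear(A, W, bias1)
out2 = torch.nn.functional.linear(A, W, bias2)

grad = torch.randn_like(out1)
out1.backward(grad)
out2.backward(grad)

# No explicit tolerances: dtype-based defaults apply, as in the new assertion.
torch.testing.assert_allclose(bias1.grad, bias2.grad)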