forked from mrq/bitsandbytes-rocm
change typecast behavior
This commit is contained in:
parent
85bf5294a6
commit
e2b523d071
|
@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
state.outlier_pool = GlobalOutlierPooler.get_instance()
|
state.outlier_pool = GlobalOutlierPooler.get_instance()
|
||||||
|
|
||||||
# Cast A to fp16
|
# Cast A to fp16
|
||||||
A_dtype = A.dtype
|
if A.dtype != torch.float16:
|
||||||
if A_dtype != torch.float16:
|
warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")
|
||||||
warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
|
|
||||||
A = A.to(torch.float16)
|
|
||||||
|
|
||||||
# 1. Quantize A
|
# 1. Quantize A
|
||||||
if len(A.shape) == 3:
|
if len(A.shape) == 3:
|
||||||
A = A.view(-1, A.shape[-1]).contiguous()
|
A = A.view(-1, A.shape[-1]).contiguous()
|
||||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
|
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
|
||||||
A, threshold=state.threshold
|
A.to(torch.float16), threshold=state.threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.threshold > 0.0 and coo_tensorA is not None:
|
if state.threshold > 0.0 and coo_tensorA is not None:
|
||||||
|
@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
|
|
||||||
if bias is None or bias.dtype == torch.float16:
|
if bias is None or bias.dtype == torch.float16:
|
||||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||||
output = output.to(A_dtype)
|
output = output.to(A.dtype)
|
||||||
else: # apply bias separately
|
else: # apply bias separately
|
||||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
||||||
output = output.to(A_dtype).add_(bias)
|
output = output.to(A.dtype).add_(bias)
|
||||||
|
|
||||||
# 4. Mixed-precision decomposition matmul
|
# 4. Mixed-precision decomposition matmul
|
||||||
if coo_tensorA is not None and subA is not None:
|
if coo_tensorA is not None and subA is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user