change typecast behavior
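Stop casting the input matrix to float16 in place: pass a temporary fp16 copy to F.double_quant instead, and cast the dequantized output back to A.dtype so callers get results in their original precision.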

justheuristic 2022-09-18 00:07:05 +03:00
parent 85bf5294a6
commit e2b523d071


@@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function):
         state.outlier_pool = GlobalOutlierPooler.get_instance()
 
         # Cast A to fp16
-        A_dtype = A.dtype
-        if A_dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
-            A = A.to(torch.float16)
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
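The practical effect of this hunk: A is no longer mutated to fp16; only an ephemeral fp16 copy feeds quantization, so the caller's tensor keeps its dtype. A minimal pure-PyTorch sketch of that pattern (quantize_input is a hypothetical helper, and the row-wise absmax int8 quantization merely stands in for F.double_quant):

import warnings
import torch

def quantize_input(A: torch.Tensor):
    # Warn, but do not mutate A: the fp16 copy exists only for quantization.
    if A.dtype != torch.float16:
        warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")
    if len(A.shape) == 3:
        A = A.view(-1, A.shape[-1]).contiguous()
    A_fp16 = A.to(torch.float16)  # ephemeral copy; the caller's A keeps its dtype
    # Stand-in for F.double_quant: row-wise absmax quantization to int8.
    scale = A_fp16.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
    CA = torch.round(A_fp16 / scale * 127).to(torch.int8)
    return CA, scale

A = torch.randn(4, 8, dtype=torch.bfloat16)
CA, scale = quantize_input(A)
assert A.dtype == torch.bfloat16  # unchanged: only the copy was cast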
@@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function):
             if bias is None or bias.dtype == torch.float16:
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-                output = output.to(A_dtype)
+                output = output.to(A.dtype)
             else:  # apply bias separately
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-                output = output.to(A_dtype).add_(bias)
+                output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
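The second hunk keeps the same output behavior while removing the A_dtype temporary: the fp16 result of F.mm_dequant is cast back to A.dtype, and a non-fp16 bias is added only after that cast. A self-contained sketch of both branches (the random tensors stand in for the real matmul result and bias):

import torch

A = torch.randn(4, 8, dtype=torch.bfloat16)          # original input dtype
out_fp16 = torch.randn(4, 16, dtype=torch.float16)   # stand-in for the F.mm_dequant result
bias = torch.randn(16, dtype=torch.bfloat16)

if bias.dtype == torch.float16:
    # an fp16 bias would have been fused inside mm_dequant; just cast the result back
    output = out_fp16.to(A.dtype)
else:  # apply bias separately, after casting to the input dtype
    output = out_fp16.to(A.dtype).add_(bias)

assert output.dtype == A.dtype == torch.bfloat16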