Added pre/post calls do quantize_blockwise.

This commit is contained in:
Tim Dettmers 2022-11-06 17:17:51 -08:00
parent e0e697b150
commit 62a333ac40

View File

@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
"""
prev_device = pre_call(A.device)
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
is_on_gpu([code, A, absmax, out, rand])
cblocksize = ct.c_int32(blocksize)
if rand is not None:
is_on_gpu([code, A, out, absmax, rand])
assert blocksize==4096
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
# cpu
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
post_call(A.device)
return out, (absmax, code)
@ -537,6 +541,7 @@ def dequantize_blockwise(
Dequantized tensor (default: float32)
"""
assert quant_state is not None or absmax is not None
device = pre_call(A.device)
if code is None and quant_state is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@ -561,6 +566,7 @@ def dequantize_blockwise(
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
post_call(A.device)
return out