Added pre/post calls do quantize_blockwise.
This commit is contained in:
parent
e0e697b150
commit
62a333ac40
|
@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
"""
|
||||
|
||||
|
||||
prev_device = pre_call(A.device)
|
||||
if code is None:
|
||||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
|
@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
is_on_gpu([code, A, absmax, out, rand])
|
||||
cblocksize = ct.c_int32(blocksize)
|
||||
if rand is not None:
|
||||
is_on_gpu([code, A, out, absmax, rand])
|
||||
assert blocksize==4096
|
||||
assert rand.numel() >= 1024
|
||||
rand_offset = random.randint(0, 1023)
|
||||
|
@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
else:
|
||||
is_on_gpu([code, A, out, absmax])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
|
@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
# cpu
|
||||
assert rand is None
|
||||
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
post_call(A.device)
|
||||
|
||||
return out, (absmax, code)
|
||||
|
||||
|
@ -537,6 +541,7 @@ def dequantize_blockwise(
|
|||
Dequantized tensor (default: float32)
|
||||
"""
|
||||
assert quant_state is not None or absmax is not None
|
||||
device = pre_call(A.device)
|
||||
if code is None and quant_state is None:
|
||||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
|
@ -561,6 +566,7 @@ def dequantize_blockwise(
|
|||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
else:
|
||||
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
post_call(A.device)
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user