Added pre/post calls to quantize_blockwise.
parent e0e697b150
commit 62a333ac40
@@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
     """
 
 
+    prev_device = pre_call(A.device)
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         is_on_gpu([code, A, absmax, out, rand])
         cblocksize = ct.c_int32(blocksize)
         if rand is not None:
+            is_on_gpu([code, A, out, absmax, rand])
             assert blocksize==4096
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
@@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
             else:
                 raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         else:
+            is_on_gpu([code, A, out, absmax])
             if A.dtype == torch.float32:
                 lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
@@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         # cpu
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+    post_call(A.device)
 
     return out, (absmax, code)
 
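The hunks above wrap the CUDA path of quantize_blockwise in a device guard: pre_call makes A.device the active CUDA device before the C kernels are launched, post_call restores the device state afterwards, and the per-branch is_on_gpu calls assert that every tensor whose pointer is handed to the C library actually lives on a GPU. A minimal sketch of that guard pattern, under the assumption that the helpers behave like plain torch.cuda save/switch/restore wrappers (the real bitsandbytes implementations may differ in detail):

    import torch

    def pre_call(device):
        # Remember the currently active CUDA device and switch to `device`,
        # so that subsequent kernel launches target the tensor's GPU.
        prev_device = None
        if device.type == "cuda":
            prev_device = torch.device("cuda", torch.cuda.current_device())
            torch.cuda.set_device(device)
        return prev_device

    def post_call(prev_device):
        # Restore whatever CUDA device was active before pre_call.
        if prev_device is not None and prev_device.type == "cuda":
            torch.cuda.set_device(prev_device)

    def is_on_gpu(tensors):
        # Guard: every non-None tensor passed to the C kernels must be on a GPU.
        for t in tensors:
            if t is not None and t.device.type != "cuda":
                raise TypeError(f"Expected a CUDA tensor, got a tensor on {t.device}")
        return True

Note that the diff passes A.device, not the saved prev_device, to post_call; the exact restore semantics live inside the real bitsandbytes helpers and are not visible in this diff.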
@@ -537,6 +541,7 @@ def dequantize_blockwise(
         Dequantized tensor (default: float32)
     """
     assert quant_state is not None or absmax is not None
+    device = pre_call(A.device)
     if code is None and quant_state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -561,6 +566,7 @@ def dequantize_blockwise(
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     else:
         lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+    post_call(A.device)
 
     return out
 
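For context, the two functions touched here form a round trip: quantize_blockwise compresses a float tensor into 8-bit codes plus per-block absmax scales, and dequantize_blockwise reconstructs an approximation from them. A usage sketch, assuming the signatures visible in the diff (quantize_blockwise returning out, (absmax, code) and dequantize_blockwise accepting quant_state=...) and a machine with bitsandbytes and at least one CUDA device:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4096, 4096, device="cuda")   # fp32 input on the GPU

    # 8-bit blockwise quantization: returns the codes and the (absmax, code) state.
    q, quant_state = F.quantize_blockwise(A)

    # Reconstruct; the pre_call/post_call guards added by this commit ensure
    # the C kernels run against A's device even if it is not the default GPU.
    A_hat = F.dequantize_blockwise(q, quant_state=quant_state)

    # Blockwise 8-bit quantization is lossy; inspect the reconstruction error.
    print((A - A_hat).abs().max().item())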