Fixed bug in cpu quant; faster GPU dequant.
This commit is contained in:
parent
62a333ac40
commit
08fa2e7b01
|
@ -94,7 +94,6 @@ class CUDASetup(object):
|
|||
else:
|
||||
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
|
||||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
print(self.lib)
|
||||
except Exception as ex:
|
||||
self.add_log_entry(str(ex))
|
||||
self.print_log_stack()
|
||||
|
|
|
@ -458,16 +458,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
"""
|
||||
|
||||
|
||||
prev_device = pre_call(A.device)
|
||||
if code is None:
|
||||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
code = name2qmap["dynamic"]
|
||||
code = code.to(A.device)
|
||||
|
||||
if absmax is None:
|
||||
n = A.numel()
|
||||
blocksize = (blocksize if A.device.type == 'cuda' else 4096)
|
||||
blocks = n // blocksize
|
||||
blocks += 1 if n % blocksize > 0 else 0
|
||||
absmax = torch.zeros((blocks,), device=A.device)
|
||||
|
@ -477,8 +474,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
|
||||
if A.device.type != 'cpu':
|
||||
assert blocksize in [4096, 2048, 1024, 512]
|
||||
is_on_gpu([code, A, absmax, out, rand])
|
||||
cblocksize = ct.c_int32(blocksize)
|
||||
prev_device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if rand is not None:
|
||||
is_on_gpu([code, A, out, absmax, rand])
|
||||
assert blocksize==4096
|
||||
|
@ -498,11 +496,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
else:
|
||||
# cpu
|
||||
code = code.cpu()
|
||||
assert rand is None
|
||||
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
post_call(A.device)
|
||||
|
||||
return out, (absmax, code)
|
||||
|
||||
|
@ -541,32 +540,35 @@ def dequantize_blockwise(
|
|||
Dequantized tensor (default: float32)
|
||||
"""
|
||||
assert quant_state is not None or absmax is not None
|
||||
device = pre_call(A.device)
|
||||
if code is None and quant_state is None:
|
||||
if "dynamic" not in name2qmap:
|
||||
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
|
||||
code = name2qmap["dynamic"]
|
||||
code = code.to(A.device)
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code)
|
||||
else:
|
||||
absmax, code = quant_state
|
||||
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if blocksize not in [2048, 4096, 1024, 512]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
|
||||
is_on_gpu([A, out])
|
||||
if out.dtype == torch.float32:
|
||||
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.float16:
|
||||
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
else:
|
||||
code = code.cpu()
|
||||
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
post_call(A.device)
|
||||
|
||||
return out
|
||||
|
||||
|
|
|
@ -510,7 +510,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
|
||||
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
|
||||
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
|
||||
{
|
||||
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
|
@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
|
||||
__shared__ typename LoadChar::TempStorage loadchar;
|
||||
__shared__ typename StoreT::TempStorage storet;
|
||||
__shared__ float smem_code[256];
|
||||
//__shared__ float smem_code[256];
|
||||
//float local_code[16];
|
||||
|
||||
if(threadIdx.x < 256)
|
||||
smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
//if(threadIdx.x < 256)
|
||||
//smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
|
||||
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
|
||||
{
|
||||
|
@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
__syncthreads();
|
||||
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
|
||||
|
||||
// load code through read-only cache via __ldg
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
vals[j] = smem_code[qvals[j]]*local_abs_max;
|
||||
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
|
||||
|
||||
__syncthreads();
|
||||
StoreT(storet).Store(&(out[i]), vals, valid_items);
|
||||
|
@ -2798,14 +2800,14 @@ template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, flo
|
|||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
|
|
|
@ -2166,3 +2166,19 @@ def test_kbit_quantile_estimation():
|
|||
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
|
||||
err = torch.abs(val1-val2).mean()
|
||||
assert err < 0.035
|
||||
|
||||
|
||||
def test_bench_dequantization():
|
||||
a = torch.rand(1024, 1024, device='cuda').half()
|
||||
qa, SA = F.quantize_blockwise(a)
|
||||
|
||||
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
|
||||
#print(max_theoretical_mu)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
F.dequantize_blockwise(qa, SA, blocksize=2048)
|
||||
torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/1e6)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user