Fixed bug in cpu quant; faster GPU dequant.

This commit is contained in:
Tim Dettmers 2022-11-07 18:06:18 -08:00
parent 62a333ac40
commit 08fa2e7b01
5 changed files with 44 additions and 25 deletions

View File

@ -94,7 +94,6 @@ class CUDASetup(object):
else:
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
self.lib = ct.cdll.LoadLibrary(binary_path)
print(self.lib)
except Exception as ex:
self.add_log_entry(str(ex))
self.print_log_stack()

View File

@ -458,16 +458,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
"""
prev_device = pre_call(A.device)
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if absmax is None:
n = A.numel()
blocksize = (blocksize if A.device.type == 'cuda' else 4096)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
@ -477,8 +474,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
if A.device.type != 'cpu':
assert blocksize in [4096, 2048, 1024, 512]
is_on_gpu([code, A, absmax, out, rand])
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
if rand is not None:
is_on_gpu([code, A, out, absmax, rand])
assert blocksize==4096
@ -498,11 +496,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
# cpu
code = code.cpu()
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
post_call(A.device)
return out, (absmax, code)
@ -541,32 +540,35 @@ def dequantize_blockwise(
Dequantized tensor (default: float32)
"""
assert quant_state is not None or absmax is not None
device = pre_call(A.device)
if code is None and quant_state is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None:
quant_state = (absmax, code)
else:
absmax, code = quant_state
if A.device.type != 'cpu':
device = pre_call(A.device)
code = code.to(A.device)
if blocksize not in [2048, 4096, 1024, 512]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
code = code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
post_call(A.device)
return out

View File

@ -510,7 +510,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
__shared__ float smem_code[256];
//__shared__ float smem_code[256];
//float local_code[16];
if(threadIdx.x < 256)
smem_code[threadIdx.x] = code[threadIdx.x];
//if(threadIdx.x < 256)
//smem_code[threadIdx.x] = code[threadIdx.x];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
__syncthreads();
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
vals[j] = smem_code[qvals[j]]*local_abs_max;
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
__syncthreads();
StoreT(storet).Store(&(out[i]), vals, valid_items);
@ -2798,14 +2800,14 @@ template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, flo
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);

View File

@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,

View File

@ -2166,3 +2166,19 @@ def test_kbit_quantile_estimation():
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
err = torch.abs(val1-val2).mean()
assert err < 0.035
def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
qa, SA = F.quantize_blockwise(a)
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
F.dequantize_blockwise(qa, SA, blocksize=2048)
torch.cuda.synchronize()
#print((time.time()-t0)/1e6)