From 2489d819c5009e88a1572809a2f3306dace84051 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Tue, 14 Feb 2023 13:55:17 -0800
Subject: [PATCH] Added more blocksizes for stochastic rounding; fixed dequant blocksize.

---
 bitsandbytes/autograd/_functions.py |  6 +++---
 bitsandbytes/functional.py          |  5 ++---
 csrc/kernels.cu                     | 12 ++++++++++++
 csrc/ops.cu                         | 14 ++++++--------
 csrc/pythonInterface.c              |  8 ++++----
 tests/test_functional.py            | 16 ++++++++++------
 6 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c2b8773..b8b2dbc 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -413,10 +413,10 @@ class MatMulFP8(torch.autograd.Function):
 
         # 2. MatmulnN
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state).to(A.dtype)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
 
         cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state).to(B.dtype)
+        fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
 
         output = torch.matmul(fp8A, fp8B)
 
@@ -443,7 +443,7 @@ class MatMulFP8(torch.autograd.Function):
         grad_A, grad_B = None, None
 
         cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state).to(grad_output.dtype)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
 
         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 371f85c..dbc2828 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -508,13 +508,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         code = code.to(A.device)
         if rand is not None:
             is_on_gpu([code, A, out, absmax, rand])
-            assert blocksize==4096
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
             if A.dtype == torch.float32:
-                lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+                lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), cblocksize, ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
+                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), cblocksize, ct.c_int(A.numel()))
             else:
                 raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         else:
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index b32b39c..99224ad 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2797,16 +2797,28 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half
 template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 2048, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 2048, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 1024, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 1024, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 512, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 512, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 256, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 256, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 128, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 128, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<half, 64, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kQuantizeBlockwise<float, 64, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 64, 2, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 64, 2, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 
 template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
 template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index e770e10..9e01588 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -54,23 +54,21 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
 {
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
-  if(STOCHASTIC == 1)
-    assert(blocksize == 4096);
 
   if(blocksize == 4096)
     kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 2048)
-    kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 2048, 4, STOCHASTIC><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 1024)
-    kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 1024, 4, STOCHASTIC><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 512)
-    kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 512, 2, STOCHASTIC><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 256)
-    kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 256, 2, STOCHASTIC><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 128)
-    kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 128, 2, STOCHASTIC><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 64)
-    kQuantizeBlockwise<T, 64, 2, 0><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 64, 2, STOCHASTIC><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
 
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index d8b2290..d1055cd 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -77,8 +77,8 @@ void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){
 
 void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
 void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
-void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, blocksize, n); }
+void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, blocksize, n); }
 
 void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
 void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
@@ -142,8 +142,8 @@ extern "C"
 	void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
 	void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
 	void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
-	void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
-	void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
+	void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, blocksize, n); }
+	void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, blocksize, n); }
 	void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
 	void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 70fa4d0..5a24aeb 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -188,21 +188,25 @@ def test_dynamic_blockwise_quantization():
     #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
 
 
-def test_dynamic_blockwise_stochastic_quantization():
+
+@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+def test_dynamic_blockwise_stochastic_quantization(blocksize):
     diffs = []
     reldiffs = []
     rand = torch.rand(1024).cuda()
+    err = 0
     for i in range(100):
         A1 = torch.randn(1024, 1024, device="cuda")
-        C1, S1 = F.quantize_blockwise(A1, rand=rand)
-        C2, S2 = F.quantize_blockwise(A1)
+        C1, S1 = F.quantize_blockwise(A1, rand=rand, blocksize=blocksize)
+        C2, S2 = F.quantize_blockwise(A1, blocksize=blocksize)
+        A2 = F.dequantize_blockwise(C1, S1, blocksize=blocksize)
+        err += (A1-A2).abs().mean().item()/100
         # a maximum distance of quantized values of 1
         torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
         fraction_smaller = (C1 < C2).float().sum() / C1.numel()
         fraction_larger = (C1 > C2).float().sum() / C1.numel()
-        torch.testing.assert_allclose(
-            fraction_larger, fraction_smaller, atol=0.01, rtol=0
-        )
+        torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
+    assert err < 0.019
 
 
 @pytest.mark.parametrize(
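
Usage sketch: a minimal illustration of the Python-side behavior this patch enables, assuming a CUDA device and the bitsandbytes functional API as exercised in the test above. The shapes, the rand buffer, and the blocksize list mirror tests/test_functional.py; the print loop is illustrative only, and the ~0.019 mean-error figure is the test's threshold, not a guarantee.

import torch
import bitsandbytes.functional as F

# Stochastic rounding is no longer restricted to blocksize 4096: any supported
# blocksize works, and dequantize_blockwise must be given the same blocksize
# that was used to quantize (the "fixed dequant blocksize" part of the subject).
A = torch.randn(1024, 1024, device="cuda")
rand = torch.rand(1024).cuda()  # external noise buffer for stochastic rounding

for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]:
    C, state = F.quantize_blockwise(A, rand=rand, blocksize=blocksize)
    A2 = F.dequantize_blockwise(C, state, blocksize=blocksize)
    print(blocksize, (A - A2).abs().mean().item())  # mean abs quantization error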