diff --git a/csrc/kernels.cu b/csrc/kernels.cu index e7e57d7..2e61297 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -64,6 +64,33 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax) } } +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) //11 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + __device__ unsigned char dQuantizeFP4(float x) { // FP4 with bias of 3 @@ -78,42 +105,79 @@ __device__ unsigned char dQuantizeFP4(float x) // 0b010 = 8 // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, it's easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ unsigned char dQuantizeNormal(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + int sign = x < 0 ? 0b1000 : 0b0000; x = fabsf(x); if(x > 3.5f) - { if( x > 7.0f) - { if( x > 10.0f) return 0b0011+sign; else return 0b0010+sign; - } else - { if(x > 5.0f) return 0b101+sign; else return 0b100+sign; - } - } else - { if(x > 1.03125f) - { if(x > 2.5f) return 0b0111+sign; else return 0b0110+sign; - } else - { if(x > 0.03125f) return 0b0001+sign; else return 0b0000+sign; - } - } } template @@ -575,8 +639,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float for(int j = 0; j < NUM_PER_TH/2; j++) { unsigned char packed_fp4 = 0; - packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4; - packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f); + packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_fp4; } } @@ -639,8 +703,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) { - vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); - vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333); + //vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); + //vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, 
local_abs_max*0.083333); + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); } } else @@ -656,52 +722,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs } } -//template -//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store) -//{ -// -// const int n_load = n_store/2; -// const int base_idx = (blockIdx.x * TILE_SIZE); -// -// T vals[NUM_PER_TH*2]; -// unsigned char qvals[NUM_PER_TH]; -// -// int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE; -// int idx = base_idx + (threadIdx.x*NUM_PER_TH); -// -// float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]); -// -// if(valid_items == TILE_SIZE) -// { -// // we do 64 byte loads so we can 128 byte stores -// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; -// } -// else -// { -// #pragma unroll -// for(int j = 0; j < NUM_PER_TH; j++) -// if(idx+j < n_load) -// qvals[j] = A[idx+j]; -// else -// qvals[j] = 0; -// } -// -// -// #pragma unroll NUM_PER_TH -// for(int j = 0; j < NUM_PER_TH; j++) -// { -// vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f); -// vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f); -// } -// -// -// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; -// reinterpret_cast(A)[idx/16] = reinterpret_cast(local_valC)[j/num_items]; -// -// -//} - - __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { const unsigned int numThreads = blockDim.x * gridDim.x; diff --git a/tests/test_functional.py b/tests/test_functional.py index a974701..12411e3 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2246,8 +2246,10 @@ def test_fp4_quant(): err = (A1 - A2).abs().float() relerr = (err/A1.abs().float()).mean() + idx = err > 1.0 err = err.mean() + assert err.item() < 0.1 assert relerr.item() < 0.28 @@ -2256,7 
+2258,7 @@ def test_fp4_compressed_stats(): for blocksize in [128, 64]: errs1 = [] errs2 = [] - for i in range(10): + for i in range(10000): A1 = torch.randn(1024, 1024, device='cuda').half() q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize) q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True) @@ -2268,7 +2270,7 @@ def test_fp4_compressed_stats(): relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() - errs1.append(err.item()) + errs1.append(relerr.item()) assert err.item() < 0.11 assert relerr.item() < 0.28 @@ -2277,7 +2279,7 @@ def test_fp4_compressed_stats(): relerr = (err/(A1.abs().float()+1e-15)).mean() err = err.mean() - errs2.append(err.item()) + errs2.append(relerr.item()) assert err.item() < 0.11 assert relerr.item() < 0.28 @@ -2301,7 +2303,7 @@ def test_bench_fp4_dequant(): #print(max_theoretical_s*1e6) b = torch.randn(128, 1024*12, device='cuda').half() - iters = 5 + iters = 500 torch.cuda.synchronize() t0 = time.time() for i in range(iters):