diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7a62c1e..e2ca978 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -9,7 +9,7 @@ from bitsandbytes.cuda_setup.main import CUDASetup setup = CUDASetup.get_instance() -if setup.initialized != True: +if not setup.initialized: setup.run_cuda_setup() if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': setup.print_log_stack() diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index cd9573f..6bebd93 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -35,6 +35,9 @@ class CUDASetup: raise RuntimeError("Call get_instance() instead") def generate_instructions(self): + if getattr(self, 'error', False): return + print(self.error) + self.error = True if self.cuda is None: self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') @@ -84,6 +87,7 @@ class CUDASetup: self.has_printed = False self.lib = None self.initialized = False + self.error = False def run_cuda_setup(self): self.initialized = True diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95a7c4f..da9e743 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -168,7 +168,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) #for ev in evalues: - bias = 2**(exponent_bits-1)-1 + bias = 2**(exponent_bits-1)+1 + print(bias) for evalue in range(2**(exponent_bits)): for bit_pattern in lst: value = (1 if evalue != 0 else 0) @@ -176,10 +177,12 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) value += pval*(2**-(i+1)) if evalue == 0: # subnormals - value = value*2**-(bias-1) + value = value*2**-(bias) else: # normals - value = value*2**-(evalue-bias-2) + print(value, 1) + value = value*2**-(evalue-bias-1) + print(value, 2) values.append(value) if signed: values.append(-value) @@ -193,7 +196,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values.append(0) values.sort() code = torch.Tensor(values) - code /= code.max() + #code /= code.max() return code @@ -587,7 +590,7 @@ def dequantize_blockwise( code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - is_on_gpu([A, out]) + is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: @@ -602,6 +605,116 @@ def dequantize_blockwise( return out +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor: + """ + Quantize tensor A in blocks of FP4 values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype): + The quantization state to undo the quantization. + """ + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device) + + state = (absmax, input_shape, A.dtype) + + if out is None: + out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + return out, state + + +def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : tuple(torch.Tensor, torch.Size, torch.dtype) + Tuple of absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + + if quant_state is None: + assert absmax is not None and out is not None + shape = out.shape + dtype = out.dtype + else: + absmax, shape, dtype = quant_state + + + if out is None: + out = torch.empty(shape, dtype=dtype, device=A.device) + + n = out.numel() + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + return out + + + + def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: if "dynamic" not in name2qmap: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 08b9b44..a1eec68 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -43,6 +43,79 @@ __device__ float atomicMin(float* address, float val) { return __int_as_float(old); } +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 3.5f) + { + if( x > 7.0f) + { + if( x > 10.0f) + return 0b0011+sign; + else + return 0b0010+sign; + } + else + { + if(x > 5.0f) + return 0b101+sign; + else + return 0b100+sign; + } + } + else + { + if(x > 1.03125f) + { + if(x > 2.5f) + return 0b0111+sign; + else + return 0b0110+sign; + } + else + { + if(x > 0.03125f) + return 0b0001+sign; + else + return 0b0000+sign; + } + } +} + template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) { @@ -427,7 +500,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } } -template +template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { @@ -437,13 +510,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH]; + unsigned char qvals[FP4 ? NUM_PER_TH/2 : NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; @@ -454,8 +527,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; - for(int i = threadIdx.x; i < 256; i+=blockDim.x) - smem_code[i] = code[i]; + if(!FP4) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -495,61 +569,138 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) + if(FP4) { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); - else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + unsigned char packed_fp4 = 0; + packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4; + packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f); + qvals[j] = packed_fp4; + } + } + else + { + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } } __syncthreads(); - StoreChar(storec).Store(&(out[i]), qvals, valid_items); + StoreChar(storec).Store(&(out[FP4 ? i/2 : i]), qvals, FP4 ? (valid_items+1)/2 : valid_items); } } -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { - const int n_full = gridDim.x * BLOCK_SIZE; - int valid_items = 0; - const int base_idx = (blockIdx.x * BLOCK_SIZE); + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); - T vals[NUM_PER_TH]; + T vals[NUM_PER_TH*(FP4 ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - //__shared__ float smem_code[256]; - //float local_code[16]; - //if(threadIdx.x < 256) - //smem_code[threadIdx.x] = code[threadIdx.x]; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { - valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; - local_abs_max = absmax[i/BLOCK_SIZE]; + if(FP4) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + + if(FP4) + { + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f); + vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333); + } + } + else + { + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + } __syncthreads(); - StoreT(storet).Store(&(out[i]), vals, valid_items); + StoreT(storet).Store(&(out[FP4 ? i*2 : i]), vals, valid_items_store); } } +//template +//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store) +//{ +// +// const int n_load = n_store/2; +// const int base_idx = (blockIdx.x * TILE_SIZE); +// +// T vals[NUM_PER_TH*2]; +// unsigned char qvals[NUM_PER_TH]; +// +// int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE; +// int idx = base_idx + (threadIdx.x*NUM_PER_TH); +// +// float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]); +// +// if(valid_items == TILE_SIZE) +// { +// // we do 64 byte loads so we can 128 byte stores +// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; +// } +// else +// { +// #pragma unroll +// for(int j = 0; j < NUM_PER_TH; j++) +// if(idx+j < n_load) +// qvals[j] = A[idx+j]; +// else +// qvals[j] = 0; +// } +// +// +// #pragma unroll NUM_PER_TH +// for(int j = 0; j < NUM_PER_TH; j++) +// { +// vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f); +// vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f); +// } +// +// +// reinterpret_cast(qvals)[0] = reinterpret_cast(A)[idx/8]; +// reinterpret_cast(A)[idx/16] = reinterpret_cast(local_valC)[j/num_items]; +// +// +//} + __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { @@ -2523,7 +2674,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; if(idx >= colsB){ break; } - //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx); if((idx+num_items < colsB)) { if(BITS == 8) @@ -2543,8 +2693,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o #pragma unroll num_items for(int k = 0; k < num_items; k++) { - //if((float)local_valsB[k] != 0.0) - // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB); if(BITS == 8 && dequant_stats != NULL) // we do texture cache reads (__ldg) on dequant_stats which should be super fast { @@ -2789,38 +2937,42 @@ MAKE_optimizerStatic8bit2State(ADAM, float) template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index d90ea13..23aad6c 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -14,8 +14,8 @@ template__global__ void kEstimateQuantiles(T *__restrict__ const A, __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, diff --git a/csrc/ops.cu b/csrc/ops.cu index e770e10..483d915 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; @@ -58,42 +58,34 @@ template void quantizeBlockwise(float * code, T *A, assert(blocksize == 4096); if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - if(blocksize == 4096) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 2048) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 1024) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 512) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 256) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 128) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 64) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); + int tile_size = FP4 ? 1024 : 512; + + if(FP4) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + else + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -688,12 +680,16 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 31d4dd8..b3e2424 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -128,8 +128,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d8b2290..6a4bb0d 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -75,13 +75,17 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -148,6 +152,11 @@ extern "C" void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_g##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ diff --git a/tests/test_functional.py b/tests/test_functional.py index 69c200a..efdda54 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -152,7 +152,7 @@ def test_dynamic_quantization(): def test_dynamic_blockwise_quantization(): #print('') - for blocksize in [4096, 2048, 1024, 512]: + for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]: diffs = [] reldiffs = [] for i in range(100): @@ -2189,7 +2189,88 @@ def test_bench_dequantization(): torch.cuda.synchronize() t0 = time.time() for i in range(100): - F.dequantize_blockwise(qa, SA, blocksize=2048) + #F.dequantize_blockwise(qa, SA, blocksize=2048) + qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() #print((time.time()-t0)/1e6) + + +def test_fp4_quant(): + vals = list(product([0, 1], repeat=4)) + + code = {} + for bits in vals: + result = 0 + bias = 3 + sign, e1, e2, p1 = bits + idx = sign*8 + e1*4 + e2*2 + p1*1 + sign = -1.0 if sign else 1.0 + exp = e1*2 + e2*1 + if exp == 0: + # sub-normal + if p1 == 0: result = 0 + else: result = sign*0.0625 + else: + # normal + exp = 2**(-exp + bias + 1) + frac = 1.5 if p1 else 1.0 + result = sign*exp*frac + code[idx] = result + + A1 = torch.randn(1024, 1024, device='cuda').half() + qa, SA = F.quantize_fp4(A1, blocksize=64) + A2 = F.dequantize_fp4(qa, SA) + #qa, SA = F.quantize_fp4(A1, blocksize=128) + #A2 = F.dequantize_fp4(qa, SA, blocksize=128) + + #A1 = A1.flatten().sort()[0] + #A2 = A2.flatten().sort()[0] + + #print(A1) + #print(A2) + + err = (A1 - A2).abs().float() + relerr = (err/A1.abs().float()).mean() + err = err.mean() + + print(err, relerr) + + + + + #assert err.item() < 0.1 + #assert relerr.item() < 0.28 + + +def test_bench_fp4_dequant(): + blocksize = 256 + a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + qa, SA = F.quantize_fp4(a, blocksize=blocksize) + + input_size = a.numel()/2 + output_size = a.numel()*2 + num_bytes = input_size+output_size + GB = num_bytes/1e9 + max_theoretical_s = GB/768 + print(max_theoretical_s*1e6) + b = torch.randn(128, 1024*12, device='cuda').half() + + iters = 5 + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + F.dequantize_fp4(qa, SA, blocksize=blocksize) + #b.copy_(a) + torch.cuda.synchronize() + print((time.time()-t0)/iters*1e6) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + torch.matmul(b, a.t()) + torch.cuda.synchronize() + print((time.time()-t0)/iters*1e6) + + +