Added fp4 quant/dequant and dequant optimizations.
parent 0f5c394870
commit 3ac5840c03
@@ -9,7 +9,7 @@ from bitsandbytes.cuda_setup.main import CUDASetup

 setup = CUDASetup.get_instance()
-if setup.initialized != True:
+if not setup.initialized:
     setup.run_cuda_setup()
 if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
     setup.print_log_stack()
@@ -35,6 +35,9 @@ class CUDASetup:
         raise RuntimeError("Call get_instance() instead")

     def generate_instructions(self):
+        if getattr(self, 'error', False): return
+        print(self.error)
+        self.error = True
         if self.cuda is None:
             self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
             self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
@@ -84,6 +87,7 @@ class CUDASetup:
         self.has_printed = False
         self.lib = None
         self.initialized = False
+        self.error = False

     def run_cuda_setup(self):
         self.initialized = True
@@ -168,7 +168,8 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
-    bias = 2**(exponent_bits-1)-1
+    bias = 2**(exponent_bits-1)+1
+    print(bias)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -176,10 +177,12 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value += pval*(2**-(i+1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias-1)
+                value = value*2**-(bias)
             else:
                 # normals
-                value = value*2**-(evalue-bias-2)
+                print(value, 1)
+                value = value*2**-(evalue-bias-1)
+                print(value, 2)
             values.append(value)
             if signed:
                 values.append(-value)
@@ -193,7 +196,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values.append(0)
     values.sort()
     code = torch.Tensor(values)
-    code /= code.max()
+    #code /= code.max()

     return code
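
For readers tracking the map changes: below is a minimal standalone sketch of the loop as it stands after this commit, with the debug prints omitted. The exponent/precision split chosen here is illustrative only, and note that the final `code /= code.max()` normalization is commented out at this point in the history:

    import itertools

    exponent_bits, precision_bits = 4, 3        # illustrative split, not from the diff
    bias = 2**(exponent_bits - 1) + 1           # new bias (was 2**(exponent_bits-1) - 1)
    values = []
    for evalue in range(2**exponent_bits):
        for bit_pattern in itertools.product([0, 1], repeat=precision_bits):
            value = 1 if evalue != 0 else 0     # implicit leading bit for normals
            for i, pval in enumerate(bit_pattern):
                value += pval * 2**-(i + 1)
            if evalue == 0:
                value = value * 2**-bias                 # subnormals (was bias-1)
            else:
                value = value * 2**-(evalue - bias - 1)  # normals (was evalue-bias-2)
            values.append(value)
            values.append(-value)               # signed map
    values.sort()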
@@ -587,7 +590,7 @@ def dequantize_blockwise(
     code = code.to(A.device)
     if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
         raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
-    is_on_gpu([A, out])
+    is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
         lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
     elif out.dtype == torch.float16:
@@ -602,6 +605,116 @@ def dequantize_blockwise(
     return out


+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor:
+    """
+    Quantize tensor A in blocks of FP4 values.
+
+    Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The input tensor.
+    absmax : torch.Tensor
+        The absmax values.
+    out : torch.Tensor
+        The output tensor (8-bit).
+    blocksize : int
+        The blocksize used in quantization.
+
+    Returns
+    -------
+    torch.Tensor:
+        The 8-bit tensor with packed 4-bit values.
+    tuple(torch.Tensor, torch.Size, torch.dtype):
+        The quantization state to undo the quantization.
+    """
+    if A.device.type != 'cuda':
+        raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
+
+    n = A.numel()
+    input_shape = A.shape
+
+    if absmax is None:
+        blocks = n // blocksize
+        blocks += 1 if n % blocksize > 0 else 0
+        absmax = torch.zeros((blocks,), device=A.device)
+
+    state = (absmax, input_shape, A.dtype)
+
+    if out is None:
+        out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device)
+
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+
+    prev_device = pre_call(A.device)
+    is_on_gpu([A, out, absmax])
+
+    if A.dtype == torch.float32:
+        lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+    elif A.dtype == torch.float16:
+        lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
+    else:
+        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+    post_call(A.device)
+
+    return out, state
+
+
+def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
+    """
+    Dequantizes FP4 blockwise quantized values.
+
+    Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The input 8-bit tensor (packed 4-bit values).
+    quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
+        Tuple of absmax values, original tensor shape and original dtype.
+    absmax : torch.Tensor
+        The absmax values.
+    out : torch.Tensor
+        Dequantized output tensor.
+
+    Returns
+    -------
+    torch.Tensor:
+        Dequantized tensor.
+    """
+    if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
+
+    if quant_state is None:
+        assert absmax is not None and out is not None
+        shape = out.shape
+        dtype = out.dtype
+    else:
+        absmax, shape, dtype = quant_state
+
+    if out is None:
+        out = torch.empty(shape, dtype=dtype, device=A.device)
+
+    n = out.numel()
+
+    device = pre_call(A.device)
+    is_on_gpu([A, absmax, out])
+    if out.dtype == torch.float32:
+        lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+    elif out.dtype == torch.float16:
+        lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
+    else:
+        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+    post_call(A.device)
+
+    return out
+
+
 def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
     if code is None:
         if "dynamic" not in name2qmap:
csrc/kernels.cu (284 lines changed)
@@ -43,6 +43,79 @@ __device__ float atomicMin(float* address, float val) {
   return __int_as_float(old);
 }

+__device__ float dDequantizeFP4(unsigned char val, float absmax)
+{
+  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
+  if((val & 0b0110) == 0)
+  {
+    // subnormal
+    if((val & 0b0001) == 0)
+      return 0.0f;
+    else
+      return sign*0.0625f*absmax;
+  }
+  else
+  {
+    // normal
+    float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f);
+    float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f;
+
+    return sign*exponent*fraction*absmax;
+  }
+}
+
+__device__ unsigned char dQuantizeFP4(float x)
+{
+  // FP4 with bias of 3
+  // first bit is a sign
+  // subnormals
+  // 0b000 = 0
+  // 0b001 = 0.0625
+  // 0b110 = 2
+  // 0b111 = 3
+  // 0b100 = 4
+  // 0b101 = 6
+  // 0b010 = 8
+  // 0b011 = 12
+
+  int sign = x < 0 ? 0b1000 : 0b0000;
+  x = fabsf(x);
+  if(x > 3.5f)
+  {
+    if( x > 7.0f)
+    {
+      if( x > 10.0f)
+        return 0b0011+sign;
+      else
+        return 0b0010+sign;
+    }
+    else
+    {
+      if(x > 5.0f)
+        return 0b101+sign;
+      else
+        return 0b100+sign;
+    }
+  }
+  else
+  {
+    if(x > 1.03125f)
+    {
+      if(x > 2.5f)
+        return 0b0111+sign;
+      else
+        return 0b0110+sign;
+    }
+    else
+    {
+      if(x > 0.03125f)
+        return 0b0001+sign;
+      else
+        return 0b0000+sign;
+    }
+  }
+}
+
 template <int STOCHASTIC>
 __device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
 {
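
The threshold tree in dQuantizeFP4 encodes the midpoints between adjacent codebook magnitudes (0.03125, 1.03125, 2.5, 3.5, 5, 7, 10), i.e. round-to-nearest over a fixed 8-entry magnitude table plus a sign bit. Below is a Python sketch of the same mapping; the names `CODEBOOK`, `quantize_fp4_scalar`, and `dequantize_fp4_scalar` are ours, and the `min(...)` form matches the branch tree only up to tie-breaking at exact midpoints:

    # Magnitudes addressed by the low 3 bits; the sign lives in bit 3 (0b1000).
    CODEBOOK = {0b000: 0.0, 0b001: 0.0625, 0b010: 8.0, 0b011: 12.0,
                0b100: 4.0, 0b101: 6.0, 0b110: 2.0, 0b111: 3.0}

    def quantize_fp4_scalar(x: float) -> int:
        """Round-to-nearest FP4 code, equivalent to the threshold tree above."""
        sign = 0b1000 if x < 0 else 0b0000
        mag = abs(x)
        return sign | min(CODEBOOK, key=lambda k: abs(CODEBOOK[k] - mag))

    def dequantize_fp4_scalar(q: int, absmax: float) -> float:
        """Mirror of dDequantizeFP4: sign times codebook magnitude times scale."""
        sign = -1.0 if q & 0b1000 else 1.0
        return sign * CODEBOOK[q & 0b0111] * absmax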
@@ -427,7 +500,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
   }
 }

-template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
+template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4>
 //__launch_bounds__(TH, 4)
 __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
 {
@@ -437,13 +510,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float

   T vals[NUM_PER_TH];
   float rand_vals[NUM_PER_TH];
-  unsigned char qvals[NUM_PER_TH];
+  unsigned char qvals[FP4 ? NUM_PER_TH/2 : NUM_PER_TH];
   //float local_abs_max = -FLT_MAX;
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;

   typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, FP4 ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
   typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
   typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
|||
__shared__ float smem_code[256];
|
||||
__shared__ float smem_absmax_value[1];
|
||||
|
||||
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
|
||||
smem_code[i] = code[i];
|
||||
if(!FP4)
|
||||
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
|
||||
smem_code[i] = code[i];
|
||||
|
||||
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
|
||||
{
|
||||
|
@@ -495,61 +569,138 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
     }

-    #pragma unroll NUM_PER_TH
-    for(int j = 0; j < NUM_PER_TH; j++)
+    if(FP4)
     {
-      if(!STOCHASTIC)
-        qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
-      else
-        qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
+      #pragma unroll NUM_PER_TH
+      for(int j = 0; j < NUM_PER_TH/2; j++)
+      {
+        unsigned char packed_fp4 = 0;
+        packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4;
+        packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f);
+        qvals[j] = packed_fp4;
+      }
+    }
+    else
+    {
+      #pragma unroll NUM_PER_TH
+      for(int j = 0; j < NUM_PER_TH; j++)
+      {
+        if(!STOCHASTIC)
+          qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
+        else
+          qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
+      }
     }

     __syncthreads();
-    StoreChar(storec).Store(&(out[i]), qvals, valid_items);
+    StoreChar(storec).Store(&(out[FP4 ? i/2 : i]), qvals, FP4 ? (valid_items+1)/2 : valid_items);
   }
 }
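
In the FP4 branch, `local_abs_max` has already been replaced by the reciprocal of the block's absmax earlier in the kernel (outside this hunk; that detail is our reading of the surrounding code), so each value is normalized to [-1, 1] and stretched by 12 to span the codebook's magnitude range; two codes are then packed per byte, first element in the high nibble. A sketch of the packing, reusing `quantize_fp4_scalar` from the earlier sketch:

    def pack_fp4_pair(v0: float, v1: float, absmax: float) -> int:
        """Pack two neighboring values of one block into a single byte."""
        q0 = quantize_fp4_scalar(v0 / absmax * 12.0)  # high nibble
        q1 = quantize_fp4_scalar(v1 / absmax * 12.0)  # low nibble
        return (q0 << 4) | q1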
-template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
-__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
+template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int FP4>
+__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
 {

-  const int n_full = gridDim.x * BLOCK_SIZE;
-  int valid_items = 0;
-  const int base_idx = (blockIdx.x * BLOCK_SIZE);
+  const int n_load = (gridDim.x * TILE_SIZE);
+  int valid_items_load = 0;
+  int valid_items_store = 0;
+  const int base_idx = (blockIdx.x * TILE_SIZE);

-  T vals[NUM_PER_TH];
+  T vals[NUM_PER_TH*(FP4 ? 2 : 1)];
   unsigned char qvals[NUM_PER_TH];
   float local_abs_max = -FLT_MAX;

   typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
-  typedef cub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+  typedef cub::BlockStore<T, THREADS, NUM_PER_TH*(FP4 ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

   __shared__ typename LoadChar::TempStorage loadchar;
   __shared__ typename StoreT::TempStorage storet;
   //__shared__ float smem_code[256];
   //float local_code[16];

   //if(threadIdx.x < 256)
   //  smem_code[threadIdx.x] = code[threadIdx.x];

-  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
+  for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
   {
-    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
-    local_abs_max = absmax[i/BLOCK_SIZE];
+    if(FP4)
+    {
+      valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
+      valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;
+    }
+    else
+    {
+      valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i;
+      valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i;
+    }
+    local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]);

     __syncthreads();
-    LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
+    LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

-    // load code through read-only cache via __ldg
-    #pragma unroll NUM_PER_TH
-    for(int j = 0; j < NUM_PER_TH; j++)
-      vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
+    if(FP4)
+    {
+      #pragma unroll NUM_PER_TH
+      for(int j = 0; j < NUM_PER_TH; j++)
+      {
+        vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
+        vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
+      }
+    }
+    else
+    {
+      // load code through read-only cache via __ldg
+      #pragma unroll NUM_PER_TH
+      for(int j = 0; j < NUM_PER_TH; j++)
+        vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
+    }

     __syncthreads();
-    StoreT(storet).Store(&(out[i]), vals, valid_items);
+    StoreT(storet).Store(&(out[FP4 ? i*2 : i]), vals, valid_items_store);
   }
 }

+//template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int TILE_SIZE>
+//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store)
+//{
+//
+//  const int n_load = n_store/2;
+//  const int base_idx = (blockIdx.x * TILE_SIZE);
+//
+//  T vals[NUM_PER_TH*2];
+//  unsigned char qvals[NUM_PER_TH];
+//
+//  int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE;
+//  int idx = base_idx + (threadIdx.x*NUM_PER_TH);
+//
+//  float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]);
+//
+//  if(valid_items == TILE_SIZE)
+//  {
+//    // we do 64 byte loads so we can 128 byte stores
+//    reinterpret_cast<int2(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int2*>(A)[idx/8];
+//  }
+//  else
+//  {
+//    #pragma unroll
+//    for(int j = 0; j < NUM_PER_TH; j++)
+//      if(idx+j < n_load)
+//        qvals[j] = A[idx+j];
+//      else
+//        qvals[j] = 0;
+//  }
+//
+//
+//  #pragma unroll NUM_PER_TH
+//  for(int j = 0; j < NUM_PER_TH; j++)
+//  {
+//    vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f);
+//    vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f);
+//  }
+//
+//
+//  reinterpret_cast<int4(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int4*>(A)[idx/8];
+//  reinterpret_cast<int4*>(A)[idx/16] = reinterpret_cast<int4(&)[16]>(local_valC)[j/num_items];
+//
+//
+//}

 __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
 {
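
On the dequant side each loaded byte expands to two outputs, and the scale handed to dDequantizeFP4 is `local_abs_max*0.083333f`, i.e. absmax/12, undoing the factor of 12 applied during quantization. A matching sketch, reusing `dequantize_fp4_scalar` from the earlier sketch:

    def unpack_fp4_byte(byte: int, absmax: float) -> tuple:
        """Expand one packed byte into two dequantized values (high nibble first)."""
        hi = dequantize_fp4_scalar(byte >> 4, absmax / 12.0)
        lo = dequantize_fp4_scalar(byte & 0x0F, absmax / 12.0)
        return hi, lo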
@@ -2523,7 +2674,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
         // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached
         int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
         if(idx >= colsB){ break; }
-        //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
         if((idx+num_items < colsB))
         {
           if(BITS == 8)
@@ -2543,8 +2693,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
           #pragma unroll num_items
           for(int k = 0; k < num_items; k++)
           {
-            //if((float)local_valsB[k] != 0.0)
-            //  printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
             if(BITS == 8 && dequant_stats != NULL)
               // we do texture cache reads (__ldg) on dequant_stats which should be super fast
             {
@@ -2789,38 +2937,42 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
 template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
 template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);

-template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 4096, 4, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 4096, 4, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 2048, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 2048, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 1024, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 1024, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 512, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 512, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 256, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 256, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);

-template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 256, 128, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 256, 128, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 128, 64, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 128, 64, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
-template __global__ void kDequantizeBlockwise<half, 64, 64, 1>(float *code, unsigned char * A, float * absmax, half *out, const int n);
-template __global__ void kDequantizeBlockwise<float, 64, 64, 1>(float *code, unsigned char * A, float * absmax, float *out, const int n);
+template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 2048, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 2048, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 1024, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 1024, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 512, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 512, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 256, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 256, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+
+template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 0>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 0>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);

 #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
@@ -14,8 +14,8 @@ template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A,
 __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
 __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);

-template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
-template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
+template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int FP4> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);

 template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
 __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
csrc/ops.cu (54 lines changed)
@@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

-template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
+template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
 {
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
@@ -58,42 +58,34 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
     assert(blocksize == 4096);

   if(blocksize == 4096)
-    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 2048)
-    kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 2048, 4, 0, FP4><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 1024)
-    kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 1024, 4, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 512)
-    kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 512, 2, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 256)
-    kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 256, 2, 0, FP4><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 128)
-    kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 64)
-    kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);


   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

-template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
+template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
 {
   int num_blocks = n/blocksize;
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
-  if(blocksize == 4096)
-    kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
-  else if(blocksize == 2048)
-    kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
-  else if(blocksize == 1024)
-    kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
-  else if(blocksize == 512)
-    kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
-  else if(blocksize == 256)
-    kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
-  else if(blocksize == 128)
-    kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
-  else if(blocksize == 64)
-    kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
+  int tile_size = FP4 ? 1024 : 512;
+
+  if(FP4)
+    kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
+  else
+    kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);

   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
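
This rewrite is the "dequant optimizations" half of the commit title: instead of one CUDA block per quantization block with the grid keyed to blocksize, every dequant launch now uses 64-thread blocks covering a fixed 512-byte tile (8 bytes per thread), so the grid depends only on n; blocksize is passed down solely to index absmax, and it is halved for FP4 because a block of values occupies blocksize/2 bytes. The launch arithmetic, as a sketch:

    def dequant_grid(n: int, fp4: bool) -> int:
        """Number of CUDA blocks launched by the new dequantizeBlockwise wrapper."""
        tile_size = 1024 if fp4 else 512   # output elements covered per 64-thread block
        return (n + tile_size - 1) // tile_size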
@@ -688,12 +680,16 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
 template void estimateQuantiles(half *A, float *code, float offset, int n);
 template void estimateQuantiles(float *A, float *code, float offset, int n);

-template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
-template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
+template void quantizeBlockwise<half, 0, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 0, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<half, 0, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 0, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<half, 1, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 1, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void dequantizeBlockwise<half, 0>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
+template void dequantizeBlockwise<float, 0>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
+template void dequantizeBlockwise<half, 1>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
+template void dequantizeBlockwise<float, 1>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);

 #define MAKE_optimizer32bit(name, gtype) \
 template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
@@ -128,8 +128,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in

 void quantize(float *code, float *A, unsigned char *out, int n);
 void dequantize(float *code, unsigned char *A, float *out, int n);
-template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
-template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
+template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);

 template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
     float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
@@ -75,13 +75,17 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
 void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }

-void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
-void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
-void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }

-void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
-void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 0>(code, A, absmax, out, blocksize, n); } \
+void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 0>(code, A, absmax, out, blocksize, n); }
+void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 1>(NULL, A, absmax, out, blocksize, n); } \
+void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 1>(NULL, A, absmax, out, blocksize, n); }

 #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
 void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
@@ -148,6 +152,11 @@ extern "C"
     void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
     void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }

+    void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
+    void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
+    void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
+    void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
+
     #define MAKE_CFUNC32(name, gtype, gbits) \
     void c##name##32bit_g##gbits(gtype *g, gtype *p, \
         float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
@@ -152,7 +152,7 @@ def test_dynamic_quantization():

 def test_dynamic_blockwise_quantization():
     #print('')
-    for blocksize in [4096, 2048, 1024, 512]:
+    for blocksize in [4096, 2048, 1024, 512, 256, 128, 64]:
         diffs = []
         reldiffs = []
         for i in range(100):
@@ -2189,7 +2189,88 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(100):
-        F.dequantize_blockwise(qa, SA, blocksize=2048)
+        #F.dequantize_blockwise(qa, SA, blocksize=2048)
+        qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)


+
+def test_fp4_quant():
+    vals = list(product([0, 1], repeat=4))
+
+    code = {}
+    for bits in vals:
+        result = 0
+        bias = 3
+        sign, e1, e2, p1 = bits
+        idx = sign*8 + e1*4 + e2*2 + p1*1
+        sign = -1.0 if sign else 1.0
+        exp = e1*2 + e2*1
+        if exp == 0:
+            # sub-normal
+            if p1 == 0: result = 0
+            else: result = sign*0.0625
+        else:
+            # normal
+            exp = 2**(-exp + bias + 1)
+            frac = 1.5 if p1 else 1.0
+            result = sign*exp*frac
+        code[idx] = result
+
+    A1 = torch.randn(1024, 1024, device='cuda').half()
+    qa, SA = F.quantize_fp4(A1, blocksize=64)
+    A2 = F.dequantize_fp4(qa, SA)
+    #qa, SA = F.quantize_fp4(A1, blocksize=128)
+    #A2 = F.dequantize_fp4(qa, SA, blocksize=128)
+
+    #A1 = A1.flatten().sort()[0]
+    #A2 = A2.flatten().sort()[0]
+
+    #print(A1)
+    #print(A2)
+
+    err = (A1 - A2).abs().float()
+    relerr = (err/A1.abs().float()).mean()
+    err = err.mean()
+
+    print(err, relerr)
+
+    #assert err.item() < 0.1
+    #assert relerr.item() < 0.28
+
+
+def test_bench_fp4_dequant():
+    blocksize = 256
+    a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
+    qa, SA = F.quantize_fp4(a, blocksize=blocksize)
+
+    input_size = a.numel()/2
+    output_size = a.numel()*2
+    num_bytes = input_size+output_size
+    GB = num_bytes/1e9
+    max_theoretical_s = GB/768
+    print(max_theoretical_s*1e6)
+    b = torch.randn(128, 1024*12, device='cuda').half()
+
+    iters = 5
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        F.dequantize_fp4(qa, SA, blocksize=blocksize)
+        #b.copy_(a)
+    torch.cuda.synchronize()
+    print((time.time()-t0)/iters*1e6)
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        torch.matmul(b, a.t())
+    torch.cuda.synchronize()
+    print((time.time()-t0)/iters*1e6)
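
For context on the first number the benchmark prints: the 768 in `GB/768` is presumably the assumed peak memory bandwidth of the benchmark GPU in GB/s (the test does not say which device). The worked arithmetic behind `max_theoretical_s`:

    numel = (1024 * 12 * 4) * (1024 * 12)   # elements of `a`
    read_bytes = numel / 2                  # packed FP4 input: one byte per two values
    write_bytes = numel * 2                 # fp16 output
    seconds = (read_bytes + write_bytes) / 1e9 / 768   # 768 GB/s is the test's assumption
    print(seconds * 1e6, "us lower bound per dequantize_fp4 call")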