First draft of NF4.
This commit is contained in:
parent
4ad999d144
commit
64cc05920d
|
@ -688,8 +688,13 @@ def dequantize_blockwise(
|
|||
|
||||
return out
|
||||
|
||||
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
|
||||
return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'fp4')
|
||||
|
||||
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False) -> Tensor:
|
||||
def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
|
||||
return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'nf4')
|
||||
|
||||
def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
|
||||
"""
|
||||
Quantize tensor A in blocks of 4-bit values (FP4 or NF4, selected by quant_type).
|
||||
|
||||
|
@ -705,6 +710,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
|
|||
The output tensor (8-bit).
|
||||
blocksize : int
|
||||
The blocksize used in quantization.
|
||||
quant_type : str
|
||||
The 4-bit quantization data type {fp4, nf4}
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -715,6 +722,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
|
|||
"""
|
||||
if A.device.type != 'cuda':
|
||||
raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
|
||||
if quant_type not in ['fp4', 'nf4']:
|
||||
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
|
||||
|
||||
n = A.numel()
|
||||
input_shape = A.shape
|
||||
|
@ -734,9 +743,15 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
|
|||
is_on_gpu([A, out, absmax])
|
||||
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
if quant_type == 'fp4':
|
||||
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
if quant_type == 'fp4':
|
||||
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
@ -754,8 +769,13 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
|
|||
|
||||
return out, state
|
||||
|
||||
def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
|
||||
return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'fp4')
|
||||
|
||||
def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
|
||||
def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
|
||||
return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'nf4')
|
||||
|
||||
def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
|
||||
"""
|
||||
Dequantizes 4-bit (FP4 or NF4, selected by quant_type) blockwise quantized values.
|
||||
|
||||
|
@ -771,6 +791,10 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
|
|||
The absmax values.
|
||||
out : torch.Tensor
|
||||
Dequantized output tensor.
|
||||
blocksize : int
|
||||
The blocksize used in quantization.
|
||||
quant_type : str
|
||||
The 4-bit quantization data type {fp4, nf4}
|
||||
|
||||
|
||||
Returns
|
||||
|
@ -780,6 +804,8 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
|
|||
"""
|
||||
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
|
||||
if quant_type not in ['fp4', 'nf4']:
|
||||
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
|
||||
|
||||
if quant_state is None:
|
||||
assert absmax is not None and out is not None
|
||||
|
@ -802,9 +828,15 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
|
|||
device = pre_call(A.device)
|
||||
is_on_gpu([A, absmax, out])
|
||||
if out.dtype == torch.float32:
|
||||
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
if quant_type == 'fp4':
|
||||
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
elif out.dtype == torch.float16:
|
||||
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
if quant_type == 'fp4':
|
||||
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
|
271
csrc/kernels.cu
271
csrc/kernels.cu
|
@ -140,44 +140,111 @@ __device__ unsigned char dQuantizeFP4(float x)
|
|||
return 0b0000+sign;
|
||||
}
|
||||
|
||||
__device__ float dDequantizeNF4(unsigned char val, float absmax)
|
||||
{
|
||||
// the values for this tree were generated by test_normal_map_tree
|
||||
// in the file tests/test_functional.py
|
||||
if((val & 0b1000) == 8)
|
||||
if((val & 0b0100) == 4) // 1
|
||||
if((val & 0b0010) == 2) // 11
|
||||
if((val & 0b0001) == 1) // 111
|
||||
return 1.0f*absmax;
|
||||
else
|
||||
return 0.7229568362236023f*absmax;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 110
|
||||
return 0.5626170039176941f*absmax;
|
||||
else
|
||||
return 0.44070982933044434f*absmax;
|
||||
else
|
||||
if((val & 0b0010) == 2) //10
|
||||
if((val & 0b0001) == 1) // 101
|
||||
return 0.33791524171829224f*absmax;
|
||||
else
|
||||
return 0.24611230194568634f*absmax;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 100
|
||||
return 0.16093020141124725f*absmax;
|
||||
else
|
||||
return 0.07958029955625534f*absmax;
|
||||
|
||||
else
|
||||
if((val & 0b0100) == 4) // 0
|
||||
if((val & 0b0010) == 2) //01
|
||||
if((val & 0b0001) == 1) // 011
|
||||
return 0.0f*absmax;
|
||||
else
|
||||
return -0.09105003625154495f*absmax;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 010
|
||||
return -0.18477343022823334f*absmax;
|
||||
else
|
||||
return -0.28444138169288635f*absmax;
|
||||
else
|
||||
if((val & 0b0010) == 2) //00
|
||||
if((val & 0b0001) == 1) // 001
|
||||
return -0.39491748809814453f*absmax;
|
||||
else
|
||||
return -0.5250730514526367f*absmax;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 000
|
||||
return -0.6961928009986877f*absmax;
|
||||
else
|
||||
return -1.0f*absmax;
|
||||
|
||||
}
|
||||
|
||||
__device__ unsigned char dQuantizeNormal(float x)
|
||||
{
|
||||
// FP4 with bias of 3
|
||||
// first bit is a sign
|
||||
// subnormals
|
||||
// 0b000 = 0
|
||||
// 0b001 = 0.0625
|
||||
// 0b110 = 2
|
||||
// 0b111 = 3
|
||||
// 0b100 = 4
|
||||
// 0b101 = 6
|
||||
// 0b010 = 8
|
||||
// 0b011 = 12
|
||||
|
||||
int sign = x < 0 ? 0b1000 : 0b0000;
|
||||
x = fabsf(x);
|
||||
if(x > 3.5f)
|
||||
if( x > 7.0f)
|
||||
if( x > 10.0f)
|
||||
return 0b0011+sign;
|
||||
// the values for this tree were generated by test_normal_map_tree
|
||||
// in the file tests/test_functional.py
|
||||
if(x > 0.03979014977812767f)
|
||||
if(x > 0.3893125355243683f) // 1
|
||||
if(x > 0.6427869200706482f) // 11
|
||||
if(x > 0.8614784181118011f) // 111
|
||||
return 0b1111;
|
||||
else
|
||||
return 0b1110;
|
||||
else
|
||||
return 0b0010+sign;
|
||||
if(x > 0.5016634166240692f) // 110
|
||||
return 0b1101;
|
||||
else
|
||||
return 0b1100;
|
||||
else
|
||||
if(x > 5.0f)
|
||||
return 0b101+sign;
|
||||
if(x > 0.2035212516784668f) // 10
|
||||
if(x > 0.2920137718319893f) // 101
|
||||
return 0b1011;
|
||||
else
|
||||
return 0b1010;
|
||||
else
|
||||
return 0b100+sign;
|
||||
if(x > 0.1202552504837513f) // 100
|
||||
return 0b1001;
|
||||
else
|
||||
return 0b1100;
|
||||
else
|
||||
if(x > 1.03125f)
|
||||
if(x > 2.5f)
|
||||
return 0b0111+sign;
|
||||
if(x > -0.33967943489551544f) // 0
|
||||
if(x > -0.13791173323988914f) // 01
|
||||
if(x > -0.045525018125772476f) // 011
|
||||
return 0b0111;
|
||||
else
|
||||
return 0b0110;
|
||||
else
|
||||
return 0b0110+sign;
|
||||
if(x > -0.23460740596055984f) // 010
|
||||
return 0b0101;
|
||||
else
|
||||
return 0b0100;
|
||||
else
|
||||
if(x > 0.03125f)
|
||||
return 0b0001+sign;
|
||||
if(x > -0.6106329262256622f) // 00
|
||||
if(x > -0.4599952697753906f) // 001
|
||||
return 0b0011;
|
||||
else
|
||||
return 0b0010;
|
||||
else
|
||||
return 0b0000+sign;
|
||||
if(x > -0.8480964004993439f) // 000
|
||||
return 0b0001;
|
||||
else
|
||||
return 0b0000;
|
||||
}
|
||||
|
||||
template <int STOCHASTIC>
|
||||
|
@ -564,7 +631,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4>
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
|
||||
//__launch_bounds__(TH, 4)
|
||||
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
|
||||
{
|
||||
|
@ -574,13 +641,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
|
||||
T vals[NUM_PER_TH];
|
||||
float rand_vals[NUM_PER_TH];
|
||||
unsigned char qvals[FP4 ? NUM_PER_TH/2 : NUM_PER_TH];
|
||||
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
|
||||
//float local_abs_max = -FLT_MAX;
|
||||
float local_abs_max = 0.0f;
|
||||
int local_rand_idx = 0;
|
||||
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, FP4 ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
|
||||
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
|
||||
|
@ -591,7 +658,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
__shared__ float smem_code[256];
|
||||
__shared__ float smem_absmax_value[1];
|
||||
|
||||
if(!FP4)
|
||||
if(DATA_TYPE == General8bit)
|
||||
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
|
||||
smem_code[i] = code[i];
|
||||
|
||||
|
@ -633,31 +700,41 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
|
||||
}
|
||||
|
||||
if(FP4)
|
||||
unsigned char packed_4bit = 0;
|
||||
switch(DATA_TYPE)
|
||||
{
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH/2; j++)
|
||||
{
|
||||
unsigned char packed_fp4 = 0;
|
||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
|
||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
|
||||
qvals[j] = packed_fp4;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
{
|
||||
if(!STOCHASTIC)
|
||||
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
|
||||
else
|
||||
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
|
||||
}
|
||||
case General8bit:
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
{
|
||||
if(!STOCHASTIC)
|
||||
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
|
||||
else
|
||||
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
|
||||
}
|
||||
break;
|
||||
case FP4:
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH/2; j++)
|
||||
{
|
||||
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
|
||||
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
|
||||
qvals[j] = packed_4bit;
|
||||
}
|
||||
break;
|
||||
case NF4:
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH/2; j++)
|
||||
{
|
||||
packed_4bit |= dQuantizeNormal(((float)vals[2*j])*local_abs_max) << 4;
|
||||
packed_4bit |= dQuantizeNormal(((float)vals[2*j+1])*local_abs_max);
|
||||
qvals[j] = packed_4bit;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
StoreChar(storec).Store(&(out[FP4 ? i/2 : i]), qvals, FP4 ? (valid_items+1)/2 : valid_items);
|
||||
StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2957,44 +3034,60 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
|
|||
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
|
||||
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
|
||||
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 256, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 256, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 32, 1, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 32, 1, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
|
||||
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
|
||||
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 256, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 256, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
|
||||
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
|
||||
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 0>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 0>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
|
||||
|
||||
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
|
||||
|
|
|
@ -14,8 +14,8 @@ template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A,
|
|||
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
|
||||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int FP4> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
|
|
50
csrc/ops.cu
50
csrc/ops.cu
|
@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
|
@ -60,34 +60,32 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
|
|||
if(blocksize == 4096)
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 2048)
|
||||
kQuantizeBlockwise<T, 2048, 4, 0, FP4><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 1024)
|
||||
kQuantizeBlockwise<T, 1024, 4, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 512)
|
||||
kQuantizeBlockwise<T, 512, 2, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 256)
|
||||
kQuantizeBlockwise<T, 256, 2, 0, FP4><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 128)
|
||||
kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 64)
|
||||
kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 32 and FP4 == 0)
|
||||
kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
int tile_size = FP4 ? 1024 : 512;
|
||||
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
|
||||
|
||||
if(FP4)
|
||||
kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
|
||||
if(DATA_TYPE > 0)
|
||||
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
|
||||
else
|
||||
kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
|
||||
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
@ -682,16 +680,20 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
|
|||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 0, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, 0>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, 0>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, 1>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, 1>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
|
||||
#define MAKE_optimizer32bit(name, gtype) \
|
||||
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
||||
|
|
11
csrc/ops.cuh
11
csrc/ops.cuh
|
@ -81,6 +81,13 @@ typedef enum Transform_t
|
|||
COL_AMPERE = 4,
|
||||
} Transform_t;
|
||||
|
||||
typedef enum DataType_t
|
||||
{
|
||||
General8bit = 0,
|
||||
FP4 = 1,
|
||||
NF4 = 2,
|
||||
} DataType_t;
|
||||
|
||||
class Context
|
||||
{
|
||||
public:
|
||||
|
@ -128,8 +135,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
|
|||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n);
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n);
|
||||
template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||
|
|
|
@ -76,17 +76,21 @@ MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
|||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, General8bit>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, General8bit>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 0>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 0>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 1>(NULL, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 1>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
|
||||
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||
|
||||
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
|
||||
|
@ -157,6 +161,10 @@ extern "C"
|
|||
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
|
|
|
@ -2254,16 +2254,18 @@ def test_fp4_quant():
|
|||
assert relerr.item() < 0.28
|
||||
|
||||
|
||||
def test_fp4_compressed_stats():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
|
||||
def test_4bit_compressed_stats(quant_type):
|
||||
for blocksize in [128, 64]:
|
||||
errs1 = []
|
||||
errs2 = []
|
||||
for i in range(10000):
|
||||
for i in range(10):
|
||||
A1 = torch.randn(1024, 1024, device='cuda').half()
|
||||
q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
|
||||
q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
|
||||
A2 = F.dequantize_fp4(q2, SA2)
|
||||
A3 = F.dequantize_fp4(q3, SA3)
|
||||
q2, SA2 = F.quantize_4bit_packed(A1, blocksize=blocksize, quant_type=quant_type)
|
||||
q3, SA3= F.quantize_4bit_packed(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
|
||||
A2 = F.dequantize_4bit_packed(q2, SA2, quant_type=quant_type)
|
||||
A3 = F.dequantize_4bit_packed(q3, SA3, quant_type=quant_type)
|
||||
|
||||
|
||||
err = (A1 - A2).abs().float()
|
||||
|
@ -2290,10 +2292,12 @@ def test_fp4_compressed_stats():
|
|||
|
||||
|
||||
|
||||
def test_bench_fp4_dequant():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
|
||||
def test_bench_fp4_dequant(quant_type):
|
||||
blocksize = 256
|
||||
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
|
||||
qa, SA = F.quantize_fp4(a, blocksize=blocksize)
|
||||
qa, SA = F.quantize_4bit_packed(a, blocksize=blocksize, quant_type=quant_type)
|
||||
|
||||
input_size = a.numel()/2
|
||||
output_size = a.numel()*2
|
||||
|
@ -2307,7 +2311,7 @@ def test_bench_fp4_dequant():
|
|||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
F.dequantize_fp4(qa, SA, blocksize=blocksize)
|
||||
F.dequantize_4bit_packed(qa, SA, blocksize=blocksize, quant_type=quant_type)
|
||||
#b.copy_(a)
|
||||
torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/iters*1e6)
|
||||
|
@ -2325,6 +2329,7 @@ def test_normal_map_tree():
|
|||
code = F.create_normal_map()
|
||||
values =code[:8].tolist() + code[-8:].tolist()
|
||||
num_pivots = 1
|
||||
print(values)
|
||||
while num_pivots <16:
|
||||
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
|
||||
print(idx)
|
||||
|
|
Loading…
Reference in New Issue
Block a user