First draft of NF4.

This commit is contained in:
Tim Dettmers 2023-04-02 16:10:35 -07:00
parent 4ad999d144
commit 64cc05920d
7 changed files with 289 additions and 142 deletions

View File

@ -688,8 +688,13 @@ def dequantize_blockwise(
return out
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'fp4')
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False) -> Tensor:
def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
return quantize_4bit_packed(A, absmax, out, blocksize, compress_statistics, 'nf4')
def quantize_4bit_packed(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
"""
Quantize tensor A in blocks of FP4 values.
@ -705,6 +710,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
The output tensor (8-bit).
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
@ -715,6 +722,8 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
"""
if A.device.type != 'cuda':
raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
if quant_type not in ['fp4', 'nf4']:
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
n = A.numel()
input_shape = A.shape
@ -734,9 +743,15 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
is_on_gpu([A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
if quant_type == 'fp4':
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
else:
lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
if quant_type == 'fp4':
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
else:
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
@ -754,8 +769,13 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
return out, state
def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'fp4')
def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
return dequantize_4bit_packed(A, quant_state, absmax, out, blocksize, 'nf4')
def dequantize_4bit_packed(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
"""
Dequantizes FP4 blockwise quantized values.
@ -771,6 +791,10 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
@ -780,6 +804,8 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
"""
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
if quant_type not in ['fp4', 'nf4']:
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
if quant_state is None:
assert absmax is not None and out is not None
@ -802,9 +828,15 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
device = pre_call(A.device)
is_on_gpu([A, absmax, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
if quant_type == 'fp4':
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
else:
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
if quant_type == 'fp4':
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
else:
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)

View File

@ -140,44 +140,111 @@ __device__ unsigned char dQuantizeFP4(float x)
return 0b0000+sign;
}
__device__ float dDequantizeNF4(unsigned char val, float absmax)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if((val & 0b1000) == 8)
if((val & 0b0100) == 4) // 1
if((val & 0b0010) == 2) // 11
if((val & 0b0001) == 1) // 111
return 1.0f*absmax;
else
return 0.7229568362236023f*absmax;
else
if((val & 0b0001) == 1) // 110
return 0.5626170039176941f*absmax;
else
return 0.44070982933044434f*absmax;
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
return 0.33791524171829224f*absmax;
else
return 0.24611230194568634f*absmax;
else
if((val & 0b0001) == 1) // 100
return 0.16093020141124725f*absmax;
else
return 0.07958029955625534f*absmax;
else
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 011
return 0.0f*absmax;
else
return -0.09105003625154495f*absmax;
else
if((val & 0b0001) == 1) // 010
return -0.18477343022823334f*absmax;
else
return -0.28444138169288635f*absmax;
else
if((val & 0b0010) == 2) //00
if((val & 0b0001) == 1) // 001
return -0.39491748809814453f*absmax;
else
return -0.5250730514526367f*absmax;
else
if((val & 0b0001) == 1) // 000
return -0.6961928009986877f*absmax;
else
return -1.0f*absmax;
}
__device__ unsigned char dQuantizeNormal(float x)
{
// FP4 with bias of 3
// first bit is a sign
// subnormals
// 0b000 = 0
// 0b001 = 0.0625
// 0b110 = 2
// 0b111 = 3
// 0b100 = 4
// 0b101 = 6
// 0b010 = 8
// 0b011 = 12
int sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if(x > 3.5f)
if( x > 7.0f)
if( x > 10.0f)
return 0b0011+sign;
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if(x > 0.03979014977812767f)
if(x > 0.3893125355243683f) // 1
if(x > 0.6427869200706482f) // 11
if(x > 0.8614784181118011f) // 111
return 0b1111;
else
return 0b1110;
else
return 0b0010+sign;
if(x > 0.5016634166240692f) // 110
return 0b1101;
else
return 0b1100;
else
if(x > 5.0f)
return 0b101+sign;
if(x > 0.2035212516784668f) // 10
if(x > 0.2920137718319893f) // 101
return 0b1011;
else
return 0b1010;
else
return 0b100+sign;
if(x > 0.1202552504837513f) // 100
return 0b1001;
else
return 0b1100;
else
if(x > 1.03125f)
if(x > 2.5f)
return 0b0111+sign;
if(x > -0.33967943489551544f) // 0
if(x > -0.13791173323988914f) // 01
if(x > -0.045525018125772476f) // 011
return 0b0111;
else
return 0b0110;
else
return 0b0110+sign;
if(x > -0.23460740596055984f) // 010
return 0b0101;
else
return 0b0100;
else
if(x > 0.03125f)
return 0b0001+sign;
if(x > -0.6106329262256622f) // 00
if(x > -0.4599952697753906f) // 001
return 0b0011;
else
return 0b0010;
else
return 0b0000+sign;
if(x > -0.8480964004993439f) // 000
return 0b0001;
else
return 0b0000;
}
template <int STOCHASTIC>
@ -564,7 +631,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4>
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
@ -574,13 +641,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[FP4 ? NUM_PER_TH/2 : NUM_PER_TH];
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, FP4 ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
@ -591,7 +658,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
__shared__ float smem_code[256];
__shared__ float smem_absmax_value[1];
if(!FP4)
if(DATA_TYPE == General8bit)
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
smem_code[i] = code[i];
@ -633,31 +700,41 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
if(FP4)
unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
unsigned char packed_fp4 = 0;
packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_fp4;
}
}
else
{
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
else
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
}
case General8bit:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
else
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
}
break;
case FP4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeNormal(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNormal(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
}
break;
}
__syncthreads();
StoreChar(storec).Store(&(out[FP4 ? i/2 : i]), qvals, FP4 ? (valid_items+1)/2 : valid_items);
StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
}
}
@ -2957,44 +3034,60 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 512, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 256, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 256, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 32, 1, 0, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 32, 1, 0, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
template __global__ void kQuantizeBlockwise<half, 4096, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 512, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 512, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 256, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 256, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 0>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 0>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \

View File

@ -14,8 +14,8 @@ template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A,
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int FP4> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int FP4> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,

View File

@ -50,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
@ -60,34 +60,32 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0, FP4><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
kQuantizeBlockwise<T, 1024, 4, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
kQuantizeBlockwise<T, 512, 2, 0, FP4><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256)
kQuantizeBlockwise<T, 256, 2, 0, FP4><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 64)
kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 32 and FP4 == 0)
kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
int tile_size = FP4 ? 1024 : 512;
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
if(FP4)
kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
if(DATA_TYPE > 0)
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
else
kDequantizeBlockwise<T, 512, 64, 8, FP4><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@ -682,16 +680,20 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 1, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half, 0>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, 0>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<half, 1>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, 1>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \

View File

@ -81,6 +81,13 @@ typedef enum Transform_t
COL_AMPERE = 4,
} Transform_t;
typedef enum DataType_t
{
General8bit = 0,
FP4 = 1,
NF4 = 2,
} DataType_t;
class Context
{
public:
@ -128,8 +135,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n);
template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T, int FP4> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,

View File

@ -76,17 +76,21 @@ MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, 0>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, 1>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, General8bit>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, General8bit>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 0>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 0>(code, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, 1>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, 1>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
@ -157,6 +161,10 @@ extern "C"
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
#define MAKE_CFUNC32(name, gtype, gbits) \
void c##name##32bit_g##gbits(gtype *g, gtype *p, \

View File

@ -2254,16 +2254,18 @@ def test_fp4_quant():
assert relerr.item() < 0.28
def test_fp4_compressed_stats():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_4bit_compressed_stats(quant_type):
for blocksize in [128, 64]:
errs1 = []
errs2 = []
for i in range(10000):
for i in range(10):
A1 = torch.randn(1024, 1024, device='cuda').half()
q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
A2 = F.dequantize_fp4(q2, SA2)
A3 = F.dequantize_fp4(q3, SA3)
q2, SA2 = F.quantize_4bit_packed(A1, blocksize=blocksize, quant_type=quant_type)
q3, SA3= F.quantize_4bit_packed(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
A2 = F.dequantize_4bit_packed(q2, SA2, quant_type=quant_type)
A3 = F.dequantize_4bit_packed(q3, SA3, quant_type=quant_type)
err = (A1 - A2).abs().float()
@ -2290,10 +2292,12 @@ def test_fp4_compressed_stats():
def test_bench_fp4_dequant():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
def test_bench_fp4_dequant(quant_type):
blocksize = 256
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
qa, SA = F.quantize_fp4(a, blocksize=blocksize)
qa, SA = F.quantize_4bit_packed(a, blocksize=blocksize, quant_type=quant_type)
input_size = a.numel()/2
output_size = a.numel()*2
@ -2307,7 +2311,7 @@ def test_bench_fp4_dequant():
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
F.dequantize_fp4(qa, SA, blocksize=blocksize)
F.dequantize_4bit_packed(qa, SA, blocksize=blocksize, quant_type=quant_type)
#b.copy_(a)
torch.cuda.synchronize()
#print((time.time()-t0)/iters*1e6)
@ -2325,6 +2329,7 @@ def test_normal_map_tree():
code = F.create_normal_map()
values =code[:8].tolist() + code[-8:].tolist()
num_pivots = 1
print(values)
while num_pivots <16:
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
print(idx)