Generalized FP4 data type.
This commit is contained in:
parent
51a21df728
commit
2dd5d69056
144
csrc/kernels.cu
144
csrc/kernels.cu
|
@ -64,6 +64,33 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
|
|||
}
|
||||
}
|
||||
|
||||
// Decodes a 4-bit FP4 code into a float scaled by absmax.
//
// val:    the code in the low four bits. Bit 3 (0b1000) is the sign bit
//         (set => negative); bits 2..0 select the magnitude.
// absmax: scale factor multiplied into the decoded magnitude.
//
// Magnitude selected by the low three bits (as a fraction of absmax):
//   0b000 -> 0.0           0b001 -> 5.208333333e-03
//   0b010 -> 0.66666667    0b011 -> 1.0
//   0b100 -> 0.33333333    0b101 -> 0.5
//   0b110 -> 0.16666667    0b111 -> 0.25
//
// The product is always evaluated as magnitude * absmax * sign, so a zero
// magnitude still propagates the sign (and any NaN/Inf in absmax).
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
  const float sign = (val & 0b1000) ? -1.0f : 1.0f;

  switch(val & 0b0111)
  {
    case 0b0111: return 0.25000000f*absmax*sign;
    case 0b0110: return 0.16666667f*absmax*sign;
    case 0b0101: return 0.50000000f*absmax*sign;
    case 0b0100: return 0.33333333f*absmax*sign;
    case 0b0011: return 1.00000000f*absmax*sign;
    case 0b0010: return 0.66666667f*absmax*sign;
    case 0b0001: return 5.208333333e-03f*absmax*sign;
    default:     return 0.00000000f*absmax*sign; // 0b0000
  }
}
|
||||
|
||||
__device__ unsigned char dQuantizeFP4(float x)
|
||||
{
|
||||
// FP4 with bias of 3
|
||||
|
@ -78,42 +105,79 @@ __device__ unsigned char dQuantizeFP4(float x)
|
|||
// 0b010 = 8
|
||||
// 0b011 = 12
|
||||
|
||||
|
||||
// we do a binary search
|
||||
// the pivots are divided by 12 (the FP4 absmax)
|
||||
// since we assume input data is in [-1.0, 1.0]
|
||||
|
||||
// !be careful here, it's easy to make a mistake
|
||||
// that is difficult to notice if you add an extra
|
||||
// zero somewhere!
|
||||
|
||||
int sign = x < 0 ? 0b1000 : 0b0000;
|
||||
x = fabsf(x);
|
||||
if(x > 0.29166667f)
|
||||
if( x > 0.583333f)
|
||||
if( x > 0.8333333f)
|
||||
return 0b0011+sign;
|
||||
else
|
||||
return 0b0010+sign;
|
||||
else
|
||||
if(x > 0.4166667f)
|
||||
return 0b101+sign;
|
||||
else
|
||||
return 0b100+sign;
|
||||
else
|
||||
if(x > 0.0859375f)
|
||||
if(x > 0.20833333f)
|
||||
return 0b0111+sign;
|
||||
else
|
||||
return 0b0110+sign;
|
||||
else
|
||||
if(x > 0.00260417f)
|
||||
return 0b0001+sign;
|
||||
else
|
||||
return 0b0000+sign;
|
||||
}
|
||||
|
||||
// Quantizes an un-normalized float to the nearest 4-bit FP4 code.
//
// Code layout: bit 3 (0b1000) is the sign bit, bits 2..0 select the magnitude.
// Magnitude table (FP4 with bias of 3; 0b000/0b001 are the subnormals):
//   0b000 = 0      0b001 = 0.0625
//   0b110 = 2      0b111 = 3
//   0b100 = 4      0b101 = 6
//   0b010 = 8      0b011 = 12
//
// Each threshold below is the midpoint between two adjacent table values,
// and every comparison is a strict '>', so an input exactly on a midpoint
// maps to the smaller magnitude. The comparisons form a depth-3 binary
// search over |x|.
//
// x: value to quantize, on the same scale as the table above (0..12).
// Returns the 4-bit code in the low nibble of the result.
__device__ unsigned char dQuantizeNormal(float x)
{
  const int signBit = (x < 0) ? 0b1000 : 0b0000;
  const float mag = fabsf(x);

  const int code =
    (mag > 3.5f)
      ? ((mag > 7.0f)
          ? ((mag > 10.0f) ? 0b0011 : 0b0010)   // 12 vs 8
          : ((mag > 5.0f) ? 0b101 : 0b100))     // 6 vs 4
      : ((mag > 1.03125f)
          ? ((mag > 2.5f) ? 0b0111 : 0b0110)    // 3 vs 2
          : ((mag > 0.03125f) ? 0b0001 : 0b0000)); // 0.0625 vs 0

  return code + signBit;
}
|
||||
|
||||
template <int STOCHASTIC>
|
||||
|
@ -575,8 +639,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
for(int j = 0; j < NUM_PER_TH/2; j++)
|
||||
{
|
||||
unsigned char packed_fp4 = 0;
|
||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4;
|
||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f);
|
||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
|
||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
|
||||
qvals[j] = packed_fp4;
|
||||
}
|
||||
}
|
||||
|
@ -639,8 +703,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
|||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
{
|
||||
vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
|
||||
vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
|
||||
//vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
|
||||
//vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
|
||||
vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
|
||||
vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
@ -656,52 +722,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
|||
}
|
||||
}
|
||||
|
||||
//template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int TILE_SIZE>
|
||||
//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store)
|
||||
//{
|
||||
//
|
||||
// const int n_load = n_store/2;
|
||||
// const int base_idx = (blockIdx.x * TILE_SIZE);
|
||||
//
|
||||
// T vals[NUM_PER_TH*2];
|
||||
// unsigned char qvals[NUM_PER_TH];
|
||||
//
|
||||
// int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE;
|
||||
// int idx = base_idx + (threadIdx.x*NUM_PER_TH);
|
||||
//
|
||||
// float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]);
|
||||
//
|
||||
// if(valid_items == TILE_SIZE)
|
||||
// {
|
||||
// // we do 64 byte loads so we can 128 byte stores
|
||||
// reinterpret_cast<int2(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int2*>(A)[idx/8];
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// #pragma unroll
|
||||
// for(int j = 0; j < NUM_PER_TH; j++)
|
||||
// if(idx+j < n_load)
|
||||
// qvals[j] = A[idx+j];
|
||||
// else
|
||||
// qvals[j] = 0;
|
||||
// }
|
||||
//
|
||||
//
|
||||
// #pragma unroll NUM_PER_TH
|
||||
// for(int j = 0; j < NUM_PER_TH; j++)
|
||||
// {
|
||||
// vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f);
|
||||
// vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f);
|
||||
// }
|
||||
//
|
||||
//
|
||||
// reinterpret_cast<int4(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int4*>(A)[idx/8];
|
||||
// reinterpret_cast<int4*>(A)[idx/16] = reinterpret_cast<int4(&)[16]>(local_valC)[j/num_items];
|
||||
//
|
||||
//
|
||||
//}
|
||||
|
||||
|
||||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
|
||||
{
|
||||
const unsigned int numThreads = blockDim.x * gridDim.x;
|
||||
|
|
|
@ -2246,8 +2246,10 @@ def test_fp4_quant():
|
|||
|
||||
err = (A1 - A2).abs().float()
|
||||
relerr = (err/A1.abs().float()).mean()
|
||||
idx = err > 1.0
|
||||
err = err.mean()
|
||||
|
||||
|
||||
assert err.item() < 0.1
|
||||
assert relerr.item() < 0.28
|
||||
|
||||
|
@ -2256,7 +2258,7 @@ def test_fp4_compressed_stats():
|
|||
for blocksize in [128, 64]:
|
||||
errs1 = []
|
||||
errs2 = []
|
||||
for i in range(10):
|
||||
for i in range(10000):
|
||||
A1 = torch.randn(1024, 1024, device='cuda').half()
|
||||
q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
|
||||
q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
|
||||
|
@ -2268,7 +2270,7 @@ def test_fp4_compressed_stats():
|
|||
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
||||
err = err.mean()
|
||||
|
||||
errs1.append(err.item())
|
||||
errs1.append(relerr.item())
|
||||
|
||||
assert err.item() < 0.11
|
||||
assert relerr.item() < 0.28
|
||||
|
@ -2277,7 +2279,7 @@ def test_fp4_compressed_stats():
|
|||
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
||||
err = err.mean()
|
||||
|
||||
errs2.append(err.item())
|
||||
errs2.append(relerr.item())
|
||||
|
||||
assert err.item() < 0.11
|
||||
assert relerr.item() < 0.28
|
||||
|
@ -2301,7 +2303,7 @@ def test_bench_fp4_dequant():
|
|||
#print(max_theoretical_s*1e6)
|
||||
b = torch.randn(128, 1024*12, device='cuda').half()
|
||||
|
||||
iters = 5
|
||||
iters = 500
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
|
|
Loading…
Reference in New Issue
Block a user