Generalized FP4 data type.

Tim Dettmers 2023-04-02 12:42:01 -07:00
parent 51a21df728
commit 2dd5d69056
2 changed files with 88 additions and 66 deletions


@@ -64,6 +64,33 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
  }
}
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
  if((val & 0b0100) == 4) // 1xx
    if((val & 0b0010) == 2) // 11x
      if((val & 0b0001) == 1) // 111
        return 0.25000000f*absmax*sign; // 0b0111 = 3/12
      else
        return 0.16666667f*absmax*sign; // 0b0110 = 2/12
    else
      if((val & 0b0001) == 1) // 101
        return 0.50000000f*absmax*sign; // 0b0101 = 6/12
      else
        return 0.33333333f*absmax*sign; // 0b0100 = 4/12
  else
    if((val & 0b0010) == 2) // 01x
      if((val & 0b0001) == 1) // 011
        return 1.00000000f*absmax*sign; // 0b0011 = 12/12
      else
        return 0.66666667f*absmax*sign; // 0b0010 = 8/12
    else
      if((val & 0b0001) == 1) // 001
        return 5.208333333e-03f*absmax*sign; // 0b0001 = 0.0625/12
      else
        return 0.00000000f*absmax*sign; // 0b0000 = 0
}
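
For reference, the branchy tree above is a hand-unrolled lookup over the eight magnitude codes, with the sign carried in bit 3, presumably to avoid a table load per decoded value. A minimal Python sketch of the same mapping (helper names are ours, not part of the commit):

    # The 8 FP4 magnitude codes (low 3 bits) and their normalized values,
    # matching the constants in dDequantizeFP4Tree; bit 3 holds the sign.
    FP4_VALUES = [
        0.0,         # 0b000
        0.00520833,  # 0b001 (0.0625/12)
        0.66666667,  # 0b010 (8/12)
        1.0,         # 0b011 (12/12)
        0.33333333,  # 0b100 (4/12)
        0.5,         # 0b101 (6/12)
        0.16666667,  # 0b110 (2/12)
        0.25,        # 0b111 (3/12)
    ]

    def dequantize_fp4(val: int, absmax: float) -> float:
        # Decode one 4-bit code against its block absmax.
        sign = -1.0 if (val & 0b1000) else 1.0
        return sign * FP4_VALUES[val & 0b0111] * absmax
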
__device__ unsigned char dQuantizeFP4(float x)
{
  // FP4 with bias of 3
@@ -78,42 +105,79 @@ __device__ unsigned char dQuantizeFP4(float x)
  // 0b010 = 8
  // 0b011 = 12
  // we do a binary search
  // the pivots are divided by 12 (the FP4 absmax)
  // since we assume the input data is in [-1.0, 1.0]
  // !be careful here, it's easy to make a mistake
  // that is difficult to notice if you add an extra
  // zero somewhere!
  int sign = x < 0 ? 0b1000 : 0b0000;
  x = fabsf(x);
  if(x > 0.29166667f)
    if( x > 0.583333f)
      if( x > 0.8333333f)
        return 0b0011+sign;
      else
        return 0b0010+sign;
    else
      if(x > 0.4166667f)
        return 0b0101+sign;
      else
        return 0b0100+sign;
  else
    if(x > 0.0859375f)
      if(x > 0.20833333f)
        return 0b0111+sign;
      else
        return 0b0110+sign;
    else
      if(x > 0.00260417f)
        return 0b0001+sign;
      else
        return 0b0000+sign;
}
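
The pivot constants are easy to misread (hence the warning above), so here is a quick reproduction of where they come from: each pivot is the midpoint between two adjacent FP4 magnitudes, divided by 12. The quantize_fp4 helper is our sketch of the same binary search, not the library's API:

    # Reproduce the pivots of dQuantizeFP4: midpoints between adjacent
    # FP4 magnitudes, divided by 12 (the FP4 absmax), since the input is
    # assumed to be normalized to [-1.0, 1.0].
    import bisect

    levels = [0.0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0]  # sorted magnitudes
    codes  = [0b000, 0b001, 0b110, 0b111, 0b100, 0b101, 0b010, 0b011]
    pivots = [(lo + hi) / 2.0 / 12.0 for lo, hi in zip(levels, levels[1:])]
    # pivots == [0.00260417, 0.0859375, 0.20833333, 0.29166667,
    #            0.41666667, 0.58333333, 0.83333333]

    def quantize_fp4(x: float) -> int:
        # Equivalent to the branchy search: count pivots strictly below |x|.
        sign = 0b1000 if x < 0 else 0b0000
        return codes[bisect.bisect_left(pivots, abs(x))] + sign
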
__device__ unsigned char dQuantizeNormal(float x)
{
  // FP4 with bias of 3
  // first bit is a sign
  // subnormals
  // 0b000 = 0
  // 0b001 = 0.0625
  // 0b110 = 2
  // 0b111 = 3
  // 0b100 = 4
  // 0b101 = 6
  // 0b010 = 8
  // 0b011 = 12
  int sign = x < 0 ? 0b1000 : 0b0000;
  x = fabsf(x);
  if(x > 3.5f)
  {
    if( x > 7.0f)
    {
      if( x > 10.0f)
        return 0b0011+sign;
      else
        return 0b0010+sign;
    }
    else
    {
      if(x > 5.0f)
        return 0b0101+sign;
      else
        return 0b0100+sign;
    }
  }
  else
  {
    if(x > 1.03125f)
    {
      if(x > 2.5f)
        return 0b0111+sign;
      else
        return 0b0110+sign;
    }
    else
    {
      if(x > 0.03125f)
        return 0b0001+sign;
      else
        return 0b0000+sign;
    }
  }
}
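
dQuantizeNormal applies the same midpoint rule to the raw, unnormalized FP4 magnitudes, which is where its thresholds come from. A one-line check (ours, not part of the commit):

    # The thresholds in dQuantizeNormal are midpoints between adjacent
    # raw FP4 magnitudes, with no division by 12.
    levels = [0.0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0]
    mids = [(lo + hi) / 2.0 for lo, hi in zip(levels, levels[1:])]
    # mids == [0.03125, 1.03125, 2.5, 3.5, 5.0, 7.0, 10.0]
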
template <int STOCHASTIC>
@@ -575,8 +639,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
      for(int j = 0; j < NUM_PER_TH/2; j++)
      {
        unsigned char packed_fp4 = 0;
-       packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4;
-       packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f);
+       packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
+       packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
        qvals[j] = packed_fp4;
      }
}
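
The dropped *12.0f factor on the removed lines is the point of this hunk: the new dQuantizeFP4 folds the division by 12 into its pivot constants, so the kernel only normalizes by the block statistic (local_abs_max presumably holds the reciprocal of the block absmax here, so the product lands in [-1, 1]). Each byte then carries two codes, the even element in the high nibble and the odd one in the low nibble; a sketch with hypothetical helper names:

    # Nibble packing as done above (helper names are ours).
    def pack_pair(even_code: int, odd_code: int) -> int:
        # Even element in the high nibble, odd element in the low nibble.
        return ((even_code & 0xF) << 4) | (odd_code & 0xF)

    def unpack_pair(byte: int) -> tuple[int, int]:
        return byte >> 4, byte & 0x0F
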
@@ -639,8 +703,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
      #pragma unroll NUM_PER_TH
      for(int j = 0; j < NUM_PER_TH; j++)
      {
-       vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
-       vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
+       //vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
+       //vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
+       vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
+       vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
      }
    }
else
@@ -656,52 +722,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
}
}
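
The dequantize side makes the matching change: the old dDequantizeFP4 calls scaled absmax by 0.083333 (1/12), while dDequantizeFP4Tree takes the plain absmax because the 1/12 is already baked into its return constants. A round-trip sanity check, reusing the helpers sketched earlier:

    # Quantize, pack, unpack, dequantize (quantize_fp4, pack_pair,
    # unpack_pair and dequantize_fp4 are from the sketches above).
    byte = pack_pair(quantize_fp4(0.4), quantize_fp4(-0.4))
    hi, lo = unpack_pair(byte)
    assert abs(dequantize_fp4(hi, 1.0) - 4.0/12.0) < 1e-6  # 0.4 snaps to 4/12
    assert abs(dequantize_fp4(lo, 1.0) + 4.0/12.0) < 1e-6
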
//template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int TILE_SIZE>
//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store)
//{
//
//  const int n_load = n_store/2;
//  const int base_idx = (blockIdx.x * TILE_SIZE);
//
//  T vals[NUM_PER_TH*2];
//  unsigned char qvals[NUM_PER_TH];
//
//  int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE;
//  int idx = base_idx + (threadIdx.x*NUM_PER_TH);
//
//  float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]);
//
//  if(valid_items == TILE_SIZE)
//  {
//    // we do 64 byte loads so we can do 128 byte stores
//    reinterpret_cast<int2(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int2*>(A)[idx/8];
//  }
//  else
//  {
//    #pragma unroll
//    for(int j = 0; j < NUM_PER_TH; j++)
//      if(idx+j < n_load)
//        qvals[j] = A[idx+j];
//      else
//        qvals[j] = 0;
//  }
//
//
//  #pragma unroll NUM_PER_TH
//  for(int j = 0; j < NUM_PER_TH; j++)
//  {
//    vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f);
//    vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f);
//  }
//
//
//  reinterpret_cast<int4(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int4*>(A)[idx/8];
//  reinterpret_cast<int4*>(A)[idx/16] = reinterpret_cast<int4(&)[16]>(local_valC)[j/num_items];
//
//
//}
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
{
  const unsigned int numThreads = blockDim.x * gridDim.x;


@@ -2246,8 +2246,10 @@ def test_fp4_quant():
    err = (A1 - A2).abs().float()
    relerr = (err/A1.abs().float()).mean()
    idx = err > 1.0
    err = err.mean()
    assert err.item() < 0.1
    assert relerr.item() < 0.28
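
For context, A1 and A2 in these asserts come from an FP4 round trip earlier in the test (above this hunk); the pattern presumably follows the library's public API, roughly:

    # Hypothetical reconstruction of the round trip under test: quantize
    # to packed FP4 with per-block absmax, then dequantize back to fp16.
    import torch
    import bitsandbytes.functional as F

    A1 = torch.randn(1024, 1024, device='cuda').half()
    q, state = F.quantize_fp4(A1, blocksize=64)
    A2 = F.dequantize_fp4(q, state)
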
@@ -2256,7 +2258,7 @@ def test_fp4_compressed_stats():
    for blocksize in [128, 64]:
        errs1 = []
        errs2 = []
-       for i in range(10):
+       for i in range(10000):
            A1 = torch.randn(1024, 1024, device='cuda').half()
            q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
            q3, SA3 = F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
@@ -2268,7 +2270,7 @@ def test_fp4_compressed_stats():
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()
-           errs1.append(err.item())
+           errs1.append(relerr.item())
            assert err.item() < 0.11
            assert relerr.item() < 0.28
@@ -2277,7 +2279,7 @@ def test_fp4_compressed_stats():
            relerr = (err/(A1.abs().float()+1e-15)).mean()
            err = err.mean()
-           errs2.append(err.item())
+           errs2.append(relerr.item())
            assert err.item() < 0.11
            assert relerr.item() < 0.28
@@ -2301,7 +2303,7 @@ def test_bench_fp4_dequant():
    #print(max_theoretical_s*1e6)
    b = torch.randn(128, 1024*12, device='cuda').half()
-   iters = 5
+   iters = 500
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(iters):