Generalized FP4 data type.
This commit is contained in:
parent
51a21df728
commit
2dd5d69056
144
csrc/kernels.cu
144
csrc/kernels.cu
|
@ -64,6 +64,33 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
|
||||||
|
{
|
||||||
|
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
|
||||||
|
if((val & 0b0100) == 4) // 0
|
||||||
|
if((val & 0b0010) == 2) //01
|
||||||
|
if((val & 0b0001) == 1) // 111
|
||||||
|
return 0.25000000f*absmax*sign; // 1111
|
||||||
|
else
|
||||||
|
return 0.16666667f*absmax*sign; // 1110
|
||||||
|
else
|
||||||
|
if((val & 0b0001) == 1) // 110
|
||||||
|
return 0.50000000f*absmax*sign; // 1101
|
||||||
|
else
|
||||||
|
return 0.33333333f*absmax*sign; // 1100
|
||||||
|
else
|
||||||
|
if((val & 0b0010) == 2) //10
|
||||||
|
if((val & 0b0001) == 1) // 101
|
||||||
|
return 1.00000000f*absmax*sign; // 1011
|
||||||
|
else
|
||||||
|
return 0.66666667f*absmax*sign; // 1010
|
||||||
|
else
|
||||||
|
if((val & 0b0001) == 1) // 100
|
||||||
|
return 5.208333333e-03f*absmax*sign; // 1001
|
||||||
|
else
|
||||||
|
return 0.00000000f*absmax*sign; // 1000
|
||||||
|
}
|
||||||
|
|
||||||
__device__ unsigned char dQuantizeFP4(float x)
|
__device__ unsigned char dQuantizeFP4(float x)
|
||||||
{
|
{
|
||||||
// FP4 with bias of 3
|
// FP4 with bias of 3
|
||||||
|
@ -78,42 +105,79 @@ __device__ unsigned char dQuantizeFP4(float x)
|
||||||
// 0b010 = 8
|
// 0b010 = 8
|
||||||
// 0b011 = 12
|
// 0b011 = 12
|
||||||
|
|
||||||
|
|
||||||
|
// we do a binary search
|
||||||
|
// the pivots are divided by 12 (the FP4 absmax)
|
||||||
|
// since we assum input data is in [-1.0, 1.0]
|
||||||
|
|
||||||
|
// !be careful here, its easy to make a mistake
|
||||||
|
// that is difficult to noice if you add an extra
|
||||||
|
// zero somewhere!
|
||||||
|
|
||||||
|
int sign = x < 0 ? 0b1000 : 0b0000;
|
||||||
|
x = fabsf(x);
|
||||||
|
if(x > 0.29166667f)
|
||||||
|
if( x > 0.583333f)
|
||||||
|
if( x > 0.8333333f)
|
||||||
|
return 0b0011+sign;
|
||||||
|
else
|
||||||
|
return 0b0010+sign;
|
||||||
|
else
|
||||||
|
if(x > 0.4166667f)
|
||||||
|
return 0b101+sign;
|
||||||
|
else
|
||||||
|
return 0b100+sign;
|
||||||
|
else
|
||||||
|
if(x > 0.0859375f)
|
||||||
|
if(x > 0.20833333f)
|
||||||
|
return 0b0111+sign;
|
||||||
|
else
|
||||||
|
return 0b0110+sign;
|
||||||
|
else
|
||||||
|
if(x > 0.00260417f)
|
||||||
|
return 0b0001+sign;
|
||||||
|
else
|
||||||
|
return 0b0000+sign;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ unsigned char dQuantizeNormal(float x)
|
||||||
|
{
|
||||||
|
// FP4 with bias of 3
|
||||||
|
// first bit is a sign
|
||||||
|
// subnormals
|
||||||
|
// 0b000 = 0
|
||||||
|
// 0b001 = 0.0625
|
||||||
|
// 0b110 = 2
|
||||||
|
// 0b111 = 3
|
||||||
|
// 0b100 = 4
|
||||||
|
// 0b101 = 6
|
||||||
|
// 0b010 = 8
|
||||||
|
// 0b011 = 12
|
||||||
|
|
||||||
int sign = x < 0 ? 0b1000 : 0b0000;
|
int sign = x < 0 ? 0b1000 : 0b0000;
|
||||||
x = fabsf(x);
|
x = fabsf(x);
|
||||||
if(x > 3.5f)
|
if(x > 3.5f)
|
||||||
{
|
|
||||||
if( x > 7.0f)
|
if( x > 7.0f)
|
||||||
{
|
|
||||||
if( x > 10.0f)
|
if( x > 10.0f)
|
||||||
return 0b0011+sign;
|
return 0b0011+sign;
|
||||||
else
|
else
|
||||||
return 0b0010+sign;
|
return 0b0010+sign;
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
|
||||||
if(x > 5.0f)
|
if(x > 5.0f)
|
||||||
return 0b101+sign;
|
return 0b101+sign;
|
||||||
else
|
else
|
||||||
return 0b100+sign;
|
return 0b100+sign;
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
|
||||||
if(x > 1.03125f)
|
if(x > 1.03125f)
|
||||||
{
|
|
||||||
if(x > 2.5f)
|
if(x > 2.5f)
|
||||||
return 0b0111+sign;
|
return 0b0111+sign;
|
||||||
else
|
else
|
||||||
return 0b0110+sign;
|
return 0b0110+sign;
|
||||||
}
|
|
||||||
else
|
else
|
||||||
{
|
|
||||||
if(x > 0.03125f)
|
if(x > 0.03125f)
|
||||||
return 0b0001+sign;
|
return 0b0001+sign;
|
||||||
else
|
else
|
||||||
return 0b0000+sign;
|
return 0b0000+sign;
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int STOCHASTIC>
|
template <int STOCHASTIC>
|
||||||
|
@ -575,8 +639,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
||||||
for(int j = 0; j < NUM_PER_TH/2; j++)
|
for(int j = 0; j < NUM_PER_TH/2; j++)
|
||||||
{
|
{
|
||||||
unsigned char packed_fp4 = 0;
|
unsigned char packed_fp4 = 0;
|
||||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max*12.0f) << 4;
|
packed_fp4 |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
|
||||||
packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max*12.0f);
|
packed_fp4 |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
|
||||||
qvals[j] = packed_fp4;
|
qvals[j] = packed_fp4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -639,8 +703,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
||||||
#pragma unroll NUM_PER_TH
|
#pragma unroll NUM_PER_TH
|
||||||
for(int j = 0; j < NUM_PER_TH; j++)
|
for(int j = 0; j < NUM_PER_TH; j++)
|
||||||
{
|
{
|
||||||
vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
|
//vals[j*2] = dDequantizeFP4(qvals[j] >> 4, local_abs_max*0.083333f);
|
||||||
vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
|
//vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*0.083333);
|
||||||
|
vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
|
||||||
|
vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -656,52 +722,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int TILE_SIZE>
|
|
||||||
//__global__ void kDequantizeBlockwiseFP4(unsigned char * A, float * absmax, T *out, const int n_store)
|
|
||||||
//{
|
|
||||||
//
|
|
||||||
// const int n_load = n_store/2;
|
|
||||||
// const int base_idx = (blockIdx.x * TILE_SIZE);
|
|
||||||
//
|
|
||||||
// T vals[NUM_PER_TH*2];
|
|
||||||
// unsigned char qvals[NUM_PER_TH];
|
|
||||||
//
|
|
||||||
// int valid_items = (base_idx + TILE_SIZE) > n_load ? ((base_idx+TILE_SIZE) - n_load) : TILE_SIZE;
|
|
||||||
// int idx = base_idx + (threadIdx.x*NUM_PER_TH);
|
|
||||||
//
|
|
||||||
// float local_abs_max = __ldg(&absmax[idx/BLOCK_SIZE]);
|
|
||||||
//
|
|
||||||
// if(valid_items == TILE_SIZE)
|
|
||||||
// {
|
|
||||||
// // we do 64 byte loads so we can 128 byte stores
|
|
||||||
// reinterpret_cast<int2(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int2*>(A)[idx/8];
|
|
||||||
// }
|
|
||||||
// else
|
|
||||||
// {
|
|
||||||
// #pragma unroll
|
|
||||||
// for(int j = 0; j < NUM_PER_TH; j++)
|
|
||||||
// if(idx+j < n_load)
|
|
||||||
// qvals[j] = A[idx+j];
|
|
||||||
// else
|
|
||||||
// qvals[j] = 0;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// #pragma unroll NUM_PER_TH
|
|
||||||
// for(int j = 0; j < NUM_PER_TH; j++)
|
|
||||||
// {
|
|
||||||
// vals[j*2] = dDequantizeFP4(qvals[j] & 0xF0, local_abs_max*12.0f);
|
|
||||||
// vals[j*2 + 1] = dDequantizeFP4(qvals[j] & 0x0F, local_abs_max*12.0f);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// reinterpret_cast<int4(&)[NUM_PER_THREAD/8]>(qvals)[0] = reinterpret_cast<int4*>(A)[idx/8];
|
|
||||||
// reinterpret_cast<int4*>(A)[idx/16] = reinterpret_cast<int4(&)[16]>(local_valC)[j/num_items];
|
|
||||||
//
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
|
|
||||||
|
|
||||||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
|
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
|
||||||
{
|
{
|
||||||
const unsigned int numThreads = blockDim.x * gridDim.x;
|
const unsigned int numThreads = blockDim.x * gridDim.x;
|
||||||
|
|
|
@ -2246,8 +2246,10 @@ def test_fp4_quant():
|
||||||
|
|
||||||
err = (A1 - A2).abs().float()
|
err = (A1 - A2).abs().float()
|
||||||
relerr = (err/A1.abs().float()).mean()
|
relerr = (err/A1.abs().float()).mean()
|
||||||
|
idx = err > 1.0
|
||||||
err = err.mean()
|
err = err.mean()
|
||||||
|
|
||||||
|
|
||||||
assert err.item() < 0.1
|
assert err.item() < 0.1
|
||||||
assert relerr.item() < 0.28
|
assert relerr.item() < 0.28
|
||||||
|
|
||||||
|
@ -2256,7 +2258,7 @@ def test_fp4_compressed_stats():
|
||||||
for blocksize in [128, 64]:
|
for blocksize in [128, 64]:
|
||||||
errs1 = []
|
errs1 = []
|
||||||
errs2 = []
|
errs2 = []
|
||||||
for i in range(10):
|
for i in range(10000):
|
||||||
A1 = torch.randn(1024, 1024, device='cuda').half()
|
A1 = torch.randn(1024, 1024, device='cuda').half()
|
||||||
q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
|
q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
|
||||||
q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
|
q3, SA3= F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
|
||||||
|
@ -2268,7 +2270,7 @@ def test_fp4_compressed_stats():
|
||||||
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
||||||
err = err.mean()
|
err = err.mean()
|
||||||
|
|
||||||
errs1.append(err.item())
|
errs1.append(relerr.item())
|
||||||
|
|
||||||
assert err.item() < 0.11
|
assert err.item() < 0.11
|
||||||
assert relerr.item() < 0.28
|
assert relerr.item() < 0.28
|
||||||
|
@ -2277,7 +2279,7 @@ def test_fp4_compressed_stats():
|
||||||
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
||||||
err = err.mean()
|
err = err.mean()
|
||||||
|
|
||||||
errs2.append(err.item())
|
errs2.append(relerr.item())
|
||||||
|
|
||||||
assert err.item() < 0.11
|
assert err.item() < 0.11
|
||||||
assert relerr.item() < 0.28
|
assert relerr.item() < 0.28
|
||||||
|
@ -2301,7 +2303,7 @@ def test_bench_fp4_dequant():
|
||||||
#print(max_theoretical_s*1e6)
|
#print(max_theoretical_s*1e6)
|
||||||
b = torch.randn(128, 1024*12, device='cuda').half()
|
b = torch.randn(128, 1024*12, device='cuda').half()
|
||||||
|
|
||||||
iters = 5
|
iters = 500
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user