Turing optimization (float accumulation). 185 vs 50.
This commit is contained in:
parent
7e49b5b938
commit
eefbf60270
|
@ -3528,29 +3528,26 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
||||||
// 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
|
// 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
|
||||||
// 4 warps -> 4 loads per iter
|
// 4 warps -> 4 loads per iter
|
||||||
// 1x128 * 128x4 -> 1x4 outputs
|
// 1x128 * 128x4 -> 1x4 outputs
|
||||||
typedef cub::WarpReduce<T> WarpReduce;
|
//typedef cub::WarpReduce<T> WarpReduce;
|
||||||
|
typedef cub::WarpReduce<float> WarpReduce;
|
||||||
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
|
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
|
||||||
|
|
||||||
const int warp_idx = threadIdx.x / 32;
|
const int warp_idx = threadIdx.x / 32;
|
||||||
const int warp_lane = threadIdx.x % 32;
|
const int warp_lane = threadIdx.x % 32;
|
||||||
const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
|
const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
|
||||||
const int num_values_8bit = num_values_4bit/2;
|
const int num_values_8bit = num_values_4bit/2;
|
||||||
T local_C = T(0);
|
//T local_C = T(0.0f);
|
||||||
|
float local_C = 0.0f;
|
||||||
T lane_quant_value = nf4_data[warp_lane % 16];
|
|
||||||
|
|
||||||
unsigned char local_B_4bit[num_values_8bit];
|
unsigned char local_B_4bit[num_values_8bit];
|
||||||
T local_B[num_values_4bit];
|
T local_B[num_values_4bit];
|
||||||
T local_A[num_values_4bit];
|
T local_A[num_values_4bit];
|
||||||
__shared__ T quant_map[16*THREADS];
|
__shared__ T quant_map[16];
|
||||||
__shared__ T quant_map2[16];
|
T local_absmax = T(0.0f);
|
||||||
|
|
||||||
//for(int i = 0; i < 16; i++)
|
for(int i = threadIdx.x; i < 16; i++)
|
||||||
// quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
|
quant_map[i] = nf4_data[i];
|
||||||
//__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for(int i = 0; i < 16; i++)
|
|
||||||
quant_map2[i] = nf4_data[i];
|
|
||||||
|
|
||||||
// A: [1, K]
|
// A: [1, K]
|
||||||
// B: [N, K]
|
// B: [N, K]
|
||||||
|
@ -3559,7 +3556,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
||||||
int inner_idx_halved = inner_idx/2;
|
int inner_idx_halved = inner_idx/2;
|
||||||
int offset_B = ldb*row_B;
|
int offset_B = ldb*row_B;
|
||||||
int absidx = ((2*offset_B)+inner_idx)/blocksize;
|
int absidx = ((2*offset_B)+inner_idx)/blocksize;
|
||||||
T local_absmax = __ldg(&(absmax[absidx]));
|
local_absmax = __ldg(&(absmax[absidx]));
|
||||||
|
|
||||||
if(row_B < M)
|
if(row_B < M)
|
||||||
{
|
{
|
||||||
|
@ -3576,25 +3573,11 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(inner_idx+(num_values_4bit*32) < K)
|
#pragma unroll
|
||||||
|
for(int k = 0; k < num_values_4bit; k++)
|
||||||
{
|
{
|
||||||
// full warp is running
|
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
|
||||||
#pragma unroll
|
local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
|
||||||
for(int k = 0; k < num_values_4bit; k++)
|
|
||||||
{
|
|
||||||
local_B[k*2] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] >> 4)*local_absmax;
|
|
||||||
local_B[k*2 + 1] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] & 0x0F)*local_absmax;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// part of the warp exited already
|
|
||||||
#pragma unroll
|
|
||||||
for(int k = 0; k < num_values_4bit; k++)
|
|
||||||
{
|
|
||||||
local_B[k*2] = quant_map2[(local_B_4bit[k] >> 4)]*local_absmax;
|
|
||||||
local_B[k*2 + 1] = quant_map2[(local_B_4bit[k] & 0x0F)]*local_absmax;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(inner_idx+num_values_4bit)
|
if(inner_idx+num_values_4bit)
|
||||||
|
@ -3603,6 +3586,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
||||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
|
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
|
||||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
|
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
|
||||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
|
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
|
||||||
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
for(int k = 0; k < num_values_4bit; k++)
|
for(int k = 0; k < num_values_4bit; k++)
|
||||||
|
@ -3610,14 +3594,14 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for(int k = 0; k < num_values_4bit; k++)
|
for(int k = 0; k < num_values_4bit; k++)
|
||||||
local_C += local_A[k]*local_B[k];
|
local_C += (float)(local_A[k]*local_B[k]);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
|
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
|
||||||
|
|
||||||
if(row_B < M && warp_lane == 0)
|
if(row_B < M && warp_lane == 0)
|
||||||
out[row_B] = local_C;
|
out[row_B] = T(local_C);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2420,7 +2420,7 @@ def test_cutlass3_gemm(dtype):
|
||||||
def test_gemm_4bit(dtype):
|
def test_gemm_4bit(dtype):
|
||||||
print('')
|
print('')
|
||||||
#for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
|
#for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
|
||||||
for dim in [4096]:
|
for dim in [4*1024]:
|
||||||
errs = []
|
errs = []
|
||||||
relerrs = []
|
relerrs = []
|
||||||
max_err = 0
|
max_err = 0
|
||||||
|
@ -2485,10 +2485,10 @@ def test_gemm_4bit(dtype):
|
||||||
#print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
#print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||||
#print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
#print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
||||||
#print(dim, (max_err.item(), max_relerr.item()))
|
#print(dim, (max_err.item(), max_relerr.item()))
|
||||||
#print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
|
print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
|
||||||
#print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
|
print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
|
||||||
#assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
|
assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
|
||||||
#assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
|
assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
|
||||||
|
|
||||||
@pytest.mark.skip("Row scale has some bugs for ampere")
|
@pytest.mark.skip("Row scale has some bugs for ampere")
|
||||||
def test_managed():
|
def test_managed():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user