Vectorized loads, conflict-free NF4; 52 vs 172.
commit dfe6900b94 (parent f89ff93e26)
@@ -3519,7 +3519,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   out[col_offset + warp_lane] = smem_C[warp_lane];
 }
 
-#define num_values_4bit 16
+#define num_values_4bit 32
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
 
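Raising num_values_4bit from 16 to 32 is what makes the loads vectorizable: 32 4-bit values of B pack into 16 bytes, which is exactly one 128-bit int4 load per thread per iteration. A minimal standalone sketch of that size arithmetic (not part of the commit; hypothetical demo code, compiled as a .cu file):

#include <cuda_runtime.h>   // for int4
#include <cstdio>

#define num_values_4bit 32

int main()
{
  // 32 packed 4-bit values = 16 bytes = one int4 vector load.
  const int num_values_8bit = num_values_4bit/2;
  static_assert(num_values_8bit*sizeof(unsigned char) == sizeof(int4),
                "16 packed bytes of B == one 128-bit load");
  printf("bytes per vectorized B load: %zu\n", sizeof(int4));
  return 0;
}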
@@ -3529,72 +3529,68 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
   typedef cub::WarpReduce<T> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[4];
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
 
   const int warp_idx = threadIdx.x / 32;
   const int warp_lane = threadIdx.x % 32;
-  const int row_B = 4*blockIdx.x + warp_idx;
+  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
+  const int num_values_8bit = num_values_4bit/2;
   T local_C = T(0);
 
-  T quant_map[16];
-  #pragma unroll 16
-  for(int i = 0; i < 16; i++)
-    quant_map[i] = nf4_data[i];
+  unsigned char local_B_4bit[num_values_8bit];
 
-  unsigned char local_B_4bit[num_values_4bit/2];
   T local_B[num_values_4bit];
+  T local_A[num_values_4bit];
+  __shared__ half quant_map[16*THREADS];
 
-  // need to increase occupancy by splitting the rows, but can be done later
+  for(int i = 0; i < 16; i++)
+    quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
+  __syncthreads();
 
   // A: [1, K]
   // B: [N, K]
   for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
   {
-    int offset_B = ldb*row_B + (inner_idx/2);
-    int absidx = (2*offset_B)/blocksize;
+    int inner_idx_halved = inner_idx/2;
+    int offset_B = ldb*row_B;
+    int absidx = ((2*offset_B)+inner_idx)/blocksize;
     T local_absmax = __ldg(&(absmax[absidx]));
 
-    //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
-    #pragma unroll
-    for(int k = 0; k < num_values_4bit/2; k++)
+    if(row_B < M)
     {
-      if((inner_idx/2) < K && row_B < M)
-        local_B_4bit[k] = B[offset_B + k];
+      if((inner_idx_halved + num_values_8bit) < K)
+      {
+        reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
+      }
       else
-        local_B_4bit[k] = 0b01110111;
+      {
+        #pragma unroll
+        for(int j = 0; j < (num_values_8bit); j++)
+          if((inner_idx/2) + j < K)
+            local_B_4bit[j] = 0b01110111;
+      }
     }
 
 
-    //if(row_B < M)
-    //{
-    //  if((inner_idx/num_values_4bit) < K)
-    //    reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
-    //  else
-    //  {
-    //    for(int k = 0; k < num_values_4bit/2; k++)
-    //    {
-    //      if((inner_idx/2) < K && row_B < M)
-    //        local_B_4bit[k] = B[offset_B + k];
-    //      else
-    //        local_B_4bit[k] = 0b01110111;
-    //    }
-    //  }
-    //}
 
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
     {
-      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
-      local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+      local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS+threadIdx.x]*local_absmax;
+      local_B[k*2+ 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS+threadIdx.x]*local_absmax;
     }
 
-    //printnonzero<T>(local_B, 4, "B values: ");
+    if(inner_idx+num_values_4bit)
+    {
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
+      reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
+    }
+    else
+      for(int k = 0; k < num_values_4bit; k++)
+        local_A[k] = A[inner_idx +k];
 
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
-      local_C += A[inner_idx + k]*local_B[k];
+      local_C += local_A[k]*local_B[k];
 
   }
 
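The "conflict-free" part of the commit message is the quant_map change in the hunk above: instead of a 16-entry per-thread array (which typically spills to local memory once it is indexed with a runtime value), the NF4 codebook is replicated into a [16*THREADS] shared-memory array so each thread owns one column and a warp's 32 lookups land in 32 distinct slots. A hedged standalone sketch of just that layout (kernel name, codebook16, packed, and out are illustrative, not from the commit):

#include <cuda_fp16.h>

// Slot layout is value*THREADS + threadIdx.x: each thread reads its own
// column of the replicated 16-entry table rather than all lanes hitting
// the same 16 shared-memory words.
template <int THREADS>
__global__ void nf4_lut_demo(const half *codebook16, const unsigned char *packed,
                             half *out, int n)
{
  __shared__ half quant_map[16*THREADS];

  // Replicate the codebook: entry i goes into every thread's column
  // (assumes blockDim.x == THREADS, as in the kernel above).
  for(int i = 0; i < 16; i++)
    quant_map[threadIdx.x + (i*blockDim.x)] = codebook16[i];
  __syncthreads();

  int idx = blockIdx.x*blockDim.x + threadIdx.x;
  if(idx < n)
  {
    unsigned char b = packed[idx];
    // High and low nibble each index this thread's own column.
    out[2*idx]     = quant_map[(b >> 4)*THREADS + threadIdx.x];
    out[2*idx + 1] = quant_map[(b & 0x0F)*THREADS + threadIdx.x];
  }
}

template __global__ void nf4_lut_demo<128>(const half*, const unsigned char*, half*, int);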
@@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 
+template __global__ void kgemm_4bit_inference_naive<half, 32>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 
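Since temp_storage and row_B are now derived from THREADS, each explicit instantiation is a warps-per-block variant of the same kernel, with one warp reducing one row of B; the new <half, 32> instantiation is the single-warp case. A tiny host-side sketch of that arithmetic (hypothetical helper, not from the commit):

// One warp per output row: THREADS/32 rows of B per block.
constexpr int rows_per_block(int threads) { return threads / 32; }
static_assert(rows_per_block(32)  == 1, "new single-warp variant");
static_assert(rows_per_block(96)  == 3, "three warps per block");
static_assert(rows_per_block(128) == 4, "four warps per block");
static_assert(rows_per_block(160) == 5, "five warps per block");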
@@ -733,6 +733,7 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
 {
 
   int num_blocks = (m+3)/4;
+  //int num_blocks = m;
 
   cout << num_blocks << endl;
   //cout << lda << endl;
 
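Here num_blocks = (m+3)/4 is ceil(m/4), which pairs with a 128-thread (four-warp) block where each warp produces one output row. A hedged sketch of that pairing; the actual launch line is outside this hunk, so the <<<...>>> call below is an assumption, reusing the parameters visible in the hunk header:

// Assumption: THREADS = 128, so each block covers 4 rows of B and
// num_blocks = ceil(m / 4) covers all m rows.
int num_blocks = (m + 3) / 4;
kgemm_4bit_inference_naive<half, 128><<<num_blocks, 128>>>(
    m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);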
@@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
     #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4096, 5120, 6656, 8192]:
     #for dim in [32]:
-    for dim in [4096]:
+    for dim in [2*4096]:
     #for dim in [5120]:
     #for dim in [6656]:
-    #for dim in [128]:
+    #for dim in [4]:
         errs = []
         relerrs = []
         max_err = 0
         max_relerr = 0
-        for i in range(1):
+        for i in range(100):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
-            B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+            A = torch.randn(1, dim, dtype=dtype, device='cuda')
+            B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
            #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
 
            #print('')
 