Vectorized loads, conflict free NF4; 52 vs 172.

This commit is contained in:
Tim Dettmers 2023-07-04 15:20:10 -07:00
parent f89ff93e26
commit dfe6900b94
3 changed files with 44 additions and 46 deletions

View File

@ -3519,7 +3519,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
out[col_offset + warp_lane] = smem_C[warp_lane];
}
#define num_values_4bit 16
#define num_values_4bit 32
template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
@ -3529,72 +3529,68 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
// 4 warps -> 4 loads per iter
// 1x128 * 128x4 -> 1x4 outputs
typedef cub::WarpReduce<T> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[4];
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
const int warp_idx = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32;
const int row_B = 4*blockIdx.x + warp_idx;
const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
const int num_values_8bit = num_values_4bit/2;
T local_C = T(0);
T quant_map[16];
#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
unsigned char local_B_4bit[num_values_4bit/2];
unsigned char local_B_4bit[num_values_8bit];
T local_B[num_values_4bit];
T local_A[num_values_4bit];
__shared__ half quant_map[16*THREADS];
// need to increase occupancy by splitting the rows, but can be done later
for(int i = 0; i < 16; i++)
quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
__syncthreads();
// A: [1, K]
// B: [N, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
{
int offset_B = ldb*row_B + (inner_idx/2);
int absidx = (2*offset_B)/blocksize;
int inner_idx_halved = inner_idx/2;
int offset_B = ldb*row_B;
int absidx = ((2*offset_B)+inner_idx)/blocksize;
T local_absmax = __ldg(&(absmax[absidx]));
//printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
#pragma unroll
for(int k = 0; k < num_values_4bit/2; k++)
if(row_B < M)
{
if((inner_idx/2) < K && row_B < M)
local_B_4bit[k] = B[offset_B + k];
if((inner_idx_halved + num_values_8bit) < K)
{
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
}
else
local_B_4bit[k] = 0b01110111;
{
#pragma unroll
for(int j = 0; j < (num_values_8bit); j++)
if((inner_idx/2) + j < K)
local_B_4bit[j] = 0b01110111;
}
}
//if(row_B < M)
//{
// if((inner_idx/num_values_4bit) < K)
// reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
// else
// {
// for(int k = 0; k < num_values_4bit/2; k++)
// {
// if((inner_idx/2) < K && row_B < M)
// local_B_4bit[k] = B[offset_B + k];
// else
// local_B_4bit[k] = 0b01110111;
// }
// }
//}
#pragma unroll
for(int k = 0; k < num_values_4bit; k++)
{
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS+threadIdx.x]*local_absmax;
local_B[k*2+ 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS+threadIdx.x]*local_absmax;
}
//printnonzero<T>(local_B, 4, "B values: ");
if(inner_idx+num_values_4bit)
{
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
}
else
for(int k = 0; k < num_values_4bit; k++)
local_A[k] = A[inner_idx +k];
#pragma unroll
for(int k = 0; k < num_values_4bit; k++)
local_C += A[inner_idx + k]*local_B[k];
local_C += local_A[k]*local_B[k];
}
@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 32>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

View File

@ -733,6 +733,7 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
{
int num_blocks = (m+3)/4;
//int num_blocks = m;
cout << num_blocks << endl;
//cout << lda << endl;

View File

@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [32]:
for dim in [4096]:
for dim in [2*4096]:
#for dim in [5120]:
#for dim in [6656]:
#for dim in [128]:
#for dim in [4]:
errs = []
relerrs = []
max_err = 0
max_relerr = 0
for i in range(1):
for i in range(100):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
#B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
#print('')