Vectorized loads, conflict free NF4; 52 vs 172.
This commit is contained in:
parent
f89ff93e26
commit
dfe6900b94
|
@ -3519,7 +3519,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
|
|||
out[col_offset + warp_lane] = smem_C[warp_lane];
|
||||
}
|
||||
|
||||
#define num_values_4bit 16
|
||||
#define num_values_4bit 32
|
||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{
|
||||
|
||||
|
@ -3529,72 +3529,68 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
|
|||
// 4 warps -> 4 loads per iter
|
||||
// 1x128 * 128x4 -> 1x4 outputs
|
||||
typedef cub::WarpReduce<T> WarpReduce;
|
||||
__shared__ typename WarpReduce::TempStorage temp_storage[4];
|
||||
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
|
||||
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
const int warp_lane = threadIdx.x % 32;
|
||||
const int row_B = 4*blockIdx.x + warp_idx;
|
||||
const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
|
||||
const int num_values_8bit = num_values_4bit/2;
|
||||
T local_C = T(0);
|
||||
|
||||
T quant_map[16];
|
||||
#pragma unroll 16
|
||||
for(int i = 0; i < 16; i++)
|
||||
quant_map[i] = nf4_data[i];
|
||||
|
||||
unsigned char local_B_4bit[num_values_4bit/2];
|
||||
unsigned char local_B_4bit[num_values_8bit];
|
||||
T local_B[num_values_4bit];
|
||||
T local_A[num_values_4bit];
|
||||
__shared__ half quant_map[16*THREADS];
|
||||
|
||||
// need to increase occupancy by splitting the rows, but can be done later
|
||||
for(int i = 0; i < 16; i++)
|
||||
quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
|
||||
__syncthreads();
|
||||
|
||||
// A: [1, K]
|
||||
// B: [N, K]
|
||||
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
|
||||
{
|
||||
int offset_B = ldb*row_B + (inner_idx/2);
|
||||
int absidx = (2*offset_B)/blocksize;
|
||||
int inner_idx_halved = inner_idx/2;
|
||||
int offset_B = ldb*row_B;
|
||||
int absidx = ((2*offset_B)+inner_idx)/blocksize;
|
||||
T local_absmax = __ldg(&(absmax[absidx]));
|
||||
|
||||
//printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
|
||||
#pragma unroll
|
||||
for(int k = 0; k < num_values_4bit/2; k++)
|
||||
if(row_B < M)
|
||||
{
|
||||
if((inner_idx/2) < K && row_B < M)
|
||||
local_B_4bit[k] = B[offset_B + k];
|
||||
if((inner_idx_halved + num_values_8bit) < K)
|
||||
{
|
||||
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
|
||||
}
|
||||
else
|
||||
local_B_4bit[k] = 0b01110111;
|
||||
{
|
||||
#pragma unroll
|
||||
for(int j = 0; j < (num_values_8bit); j++)
|
||||
if((inner_idx/2) + j < K)
|
||||
local_B_4bit[j] = 0b01110111;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//if(row_B < M)
|
||||
//{
|
||||
// if((inner_idx/num_values_4bit) < K)
|
||||
// reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
|
||||
// else
|
||||
// {
|
||||
// for(int k = 0; k < num_values_4bit/2; k++)
|
||||
// {
|
||||
// if((inner_idx/2) < K && row_B < M)
|
||||
// local_B_4bit[k] = B[offset_B + k];
|
||||
// else
|
||||
// local_B_4bit[k] = 0b01110111;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < num_values_4bit; k++)
|
||||
{
|
||||
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
|
||||
local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
|
||||
local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS+threadIdx.x]*local_absmax;
|
||||
local_B[k*2+ 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS+threadIdx.x]*local_absmax;
|
||||
}
|
||||
|
||||
//printnonzero<T>(local_B, 4, "B values: ");
|
||||
if(inner_idx+num_values_4bit)
|
||||
{
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
|
||||
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
|
||||
}
|
||||
else
|
||||
for(int k = 0; k < num_values_4bit; k++)
|
||||
local_A[k] = A[inner_idx +k];
|
||||
|
||||
#pragma unroll
|
||||
for(int k = 0; k < num_values_4bit; k++)
|
||||
local_C += A[inner_idx + k]*local_B[k];
|
||||
local_C += local_A[k]*local_B[k];
|
||||
|
||||
}
|
||||
|
||||
|
@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
|
|||
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 32>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template __global__ void kgemm_4bit_inference_naive<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
|
|
@ -733,6 +733,7 @@ template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A,
|
|||
{
|
||||
|
||||
int num_blocks = (m+3)/4;
|
||||
//int num_blocks = m;
|
||||
|
||||
cout << num_blocks << endl;
|
||||
//cout << lda << endl;
|
||||
|
|
|
@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
|
|||
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
#for dim in [4096, 5120, 6656, 8192]:
|
||||
#for dim in [32]:
|
||||
for dim in [4096]:
|
||||
for dim in [2*4096]:
|
||||
#for dim in [5120]:
|
||||
#for dim in [6656]:
|
||||
#for dim in [128]:
|
||||
#for dim in [4]:
|
||||
errs = []
|
||||
relerrs = []
|
||||
max_err = 0
|
||||
max_relerr = 0
|
||||
for i in range(1):
|
||||
for i in range(100):
|
||||
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
|
||||
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
|
||||
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
|
||||
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
|
||||
A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
|
||||
B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
|
||||
A = torch.randn(1, dim, dtype=dtype, device='cuda')
|
||||
B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
|
||||
#B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
|
||||
|
||||
#print('')
|
||||
|
|
Loading…
Reference in New Issue
Block a user