4-bit draft; 128 vector load 240.

Tim Dettmers 2023-05-02 16:15:38 -07:00
parent 869b7e83b5
commit 264a948539
4 changed files with 284 additions and 142 deletions


@@ -1385,10 +1385,12 @@ def cutlass3_gemm(
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
         Bshape = B.shape
+        bout = Bshape[1]
     else:
         Bshape = state[1]
+        bout = Bshape[0]
     if out is None:
-        out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device)
+        out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
     sA = A.shape
     sB = B.shape
@@ -1464,7 +1466,7 @@ def cutlass3_gemm(
     if state is not None:
         m = Bshape[0]
         k = Bshape[1]
-        lda = Bshape[1]
+        lda = Bshape[0]
         ldc = Bshape[0]
         ldb = (ldb+1)//2
     #print(m, n, k, lda, ldb, ldc)
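
A hedged reading of the hunk above: when a quantization state is present, B is the packed 4-bit buffer and state[1] carries B's original shape, so the output width (bout) and the leading dimensions have to come from the state rather than from B.shape. A minimal C++ sketch of the argument arithmetic, assuming state[1] = (out_features, in_features); the names are illustrative, not bitsandbytes API:

#include <cstdio>

// Sketch: GEMM arguments for the packed 4-bit path (Bshape0/Bshape1 stand in
// for state[1]; treating them as (out_features, in_features) is an assumption).
static void gemm_args_4bit(int Bshape0, int Bshape1, int ldb_in)
{
    int m   = Bshape0;               // columns of the fp16 output (bout above)
    int k   = Bshape1;               // reduction dimension
    int lda = Bshape0;               // changed from Bshape1 in this commit
    int ldc = Bshape0;
    int ldb = (ldb_in + 1) / 2;      // two 4-bit values per byte halve B's byte stride
    printf("m=%d k=%d lda=%d ldb=%d ldc=%d\n", m, k, lda, ldb, ldc);
}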


@@ -3044,22 +3044,15 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
 #define WARPS 5
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
-  typedef cub::WarpReduce<half> WarpReduce;
-  // Allocate WarpReduce shared memory for one warp
-  //__shared__ typename WarpReduce::TempStorage temp_storage;
-  //typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  //// Allocate shared memory for BlockReduce
-  //__shared__ typename BlockReduce::TempStorage reduce;
   int col_offset = blockIdx.x *32;
   const int warp_id = threadIdx.x / 32;
   const int half_warp_id = threadIdx.x / 16;
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;
+  const int val_per_iter = blockDim.x-32;

-  T local_A[2];
-  T local_B[64];
+  T local_A[4];
+  T local_B[128];

   const int a_tile_offset = 16;
   const int b_tile_offset = (16*32 + 16);
@@ -3082,24 +3075,45 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     if(loaded_values == 0)
     {
       local_A[0] = A[idx];
-      local_A[1] = A[idx+blockDim.x-32];
+      local_A[1] = A[idx+(1*val_per_iter)];
+      local_A[2] = A[idx+(2*val_per_iter)];
+      local_A[3] = A[idx+(3*val_per_iter)];

       #pragma unroll 32
       for(int col = 0; col < 32; col++)
       {
         local_B[col] = B[(col_offset+col)*ldb+idx];
-        local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
+        local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
+        local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
+        local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
       }

-      loaded_values = 1;
+      loaded_values = 3;
     }
     else
     {
-      local_A[0] = local_A[1];
-      loaded_values--;
-
-      #pragma unroll 32
-      for(int col = 0; col < 32; col++)
-        local_B[col] = local_B[col+32];
+      if(loaded_values == 3)
+      {
+        local_A[0] = local_A[1];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col+(32)];
+      }
+      else if(loaded_values == 2)
+      {
+        local_A[0] = local_A[2];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col+(64)];
+      }
+      else
+      {
+        local_A[0] = local_A[3];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col+(96)];
+      }
+      loaded_values--;
     }

     smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@@ -3139,26 +3153,46 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     if(loaded_values == 0)
     {
       local_A[0] = A[idx];
-      local_A[1] = A[idx+blockDim.x-32];
+      local_A[1] = A[idx+(1*val_per_iter)];
+      local_A[2] = A[idx+(2*val_per_iter)];
+      local_A[3] = A[idx+(3*val_per_iter)];

       #pragma unroll 32
       for(int col = 0; col < 32; col++)
       {
         local_B[col] = B[(col_offset+col)*ldb+idx];
-        local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
+        local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
+        local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
+        local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
       }

-      loaded_values = 1;
+      loaded_values = 3;
     }
     else
     {
-      local_A[0] = local_A[1];
+      if(loaded_values == 3)
+      {
+        local_A[0] = local_A[1];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col+(32)];
+      }
+      else if(loaded_values == 2)
+      {
+        local_A[0] = local_A[2];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col+(64)];
+      }
+      else
+      {
+        local_A[0] = local_A[3];
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+          local_B[col] = local_B[col+(96)];
+      }
       loaded_values--;
-
-      #pragma unroll 32
-      for(int col = 0; col < 32; col++)
-        local_B[col] = local_B[col+32];
     }

     smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
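
Both hunks apply the same change: instead of fetching one extra A value and 32 extra B values per iteration, each producer thread now issues four strided loads at once (the "128 vector load" of the commit title: 4 x 32 B values per thread) and then serves the next three iterations from registers, counting loaded_values down from 3. A standalone sketch of that staging discipline, with hypothetical names and a plain float reduction in place of the real tile stores:

// Standalone sketch of the loaded_values register pipeline (hypothetical
// kernel, not the bitsandbytes one): refill four strided values at once,
// then serve three iterations from registers before touching global memory again.
__global__ void staged_load_demo(const float * __restrict__ in, float *out, int K)
{
    const int val_per_iter = blockDim.x;     // stride between the four staged loads
    float staged[4];
    int loaded_values = 0;
    float acc = 0.0f;

    for (int base = 0; base < K; base += val_per_iter)
    {
        int idx = base + threadIdx.x;
        if (idx >= K) break;

        if (loaded_values == 0)              // refill: 4 loads cover 4 iterations
        {
            for (int i = 0; i < 4; i++)
                staged[i] = (idx + i*val_per_iter < K) ? in[idx + i*val_per_iter] : 0.0f;
            loaded_values = 3;               // three more iterations come from registers
        }
        else                                 // shift down, like the if/else-if chain above
        {
            staged[0] = staged[4 - loaded_values];
            loaded_values--;
        }
        acc += staged[0];                    // stand-in for the smem_A/smem_B stores
    }
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}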
@@ -3215,104 +3249,166 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
-  typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage reduce;
-  int col_offset = blockIdx.x *8;
+  int col_offset = blockIdx.x *32;
+  const int warp_id = threadIdx.x / 32;
+  const int half_warp_id = threadIdx.x / 16;
+  const int half_warp_lane = threadIdx.x % 16;
+  const int batch_size_warps = (WARPS-1)*2;

-  T local_A[32];
-  unsigned char local_B_4bit[16];
-  T local_B[32];
-  T local_C[8];
+  T local_A[2];
+  T local_B[64];
+  unsigned char local_B_4bit[32];

-  __shared__ T smem_C[8];
+  const int a_tile_offset = 16;
+  const int b_tile_offset = (16*32 + 16);

-  if(threadIdx.x < 8)
-    smem_C[threadIdx.x] = T(0);
-  __syncthreads();
+  __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
+  __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
+  //__shared__ T smem_C[8*32];

-  #pragma unroll 8
-  for(int k = 0; k < 8; k++)
-    local_C[k] = T(0);
+  wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
+  wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
+  wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
+  wmma::fill_fragment(c_frag, 0.0f);

-  for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32)
+  int ticktock = 0;
+  int idx = 0 + threadIdx.x;
+  int loaded_values = 0;
+  // prefetch
+  if(idx < K && warp_id < (WARPS-1))
   {
-    // we load only 8 values per iteration from A, so we
-    // need to do 4 loads for every single load from B
-    // for B, we have packed values, so the 16 8-bit values
-    // turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads
-    vector_load<T, int4, 8>(local_A, A, idx, idx, K);
-    vector_load<T, int4, 8>(&(local_A[8]), A, idx+8, idx+8, K);
-    vector_load<T, int4, 8>(&(local_A[16]), A, idx+16, idx+16, K);
-    vector_load<T, int4, 8>(&(local_A[24]), A, idx+24, idx+24, K);
-
-    for(int col = 0; col < 8; col++)
+    if(loaded_values == 0)
     {
-      if((col + col_offset) >= M){ break; }
-
-      int offset_B = (col_offset+col)*ldb;
-      // 0111 -> 0.0f in NF4
-      // since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111
-      vector_load<unsigned char, int4, 16>(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111);
-
-      int absidx = (idx + offset_B)/blocksize;
-      half local_absmax = __ldg(&(absmax[absidx]));
-
-      //for(int k = 0; k < 16; k++)
-        //printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F);
-      //printf("\n");
-      //vector_load<T, int4, 8>(local_A, A, idx, idx, K);
-
-      #pragma unroll 16
-      for(int k = 0; k < 16; k++)
-      {
-        //if(local_B_4bit[k ] != 0b01110111)
-          //printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax),
-          //local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax));
-        //local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax;
-        //local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax;
-        local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax;
-        local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax;
-        //local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax;
-        //local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax;
-      }
-
-      #pragma unroll 32
-      //for(int k = 0; k < 8; k++)
-      for(int k = 0; k < 32; k++)
-      {
-        local_C[col] += local_A[k]*local_B[k];
-        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
-        //if((float)local_B[k] != 0.0)
-          //printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]);
-      }
+      local_A[0] = A[idx];
+      local_A[1] = A[idx+blockDim.x-32];
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
+
+      loaded_values = 1;
     }
+    else
+    {
+      local_A[0] = local_A[1];
+      loaded_values--;
+
+      #pragma unroll 64
+      for(int col = 0; col < 64; col+=2)
+      {
+        local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
+        local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
+      }
+    }
+
+    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
+
+    #pragma unroll 32
+    for(int col = 0; col < 32; col++)
+      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
   }
+  else if(warp_id < (WARPS-1))
+  {
+    local_A[0] = T(0.0);
+    smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
+
+    #pragma unroll 32
+    for(int col = 0; col < 32; col++)
+      local_B[col] = 0.0f;
+
+    #pragma unroll 32
+    for(int col = 0; col < 32; col++)
+      smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+  }
+  ticktock = ticktock == 0 ? 1 : 0;

-  #pragma unroll 8
-  for(int k = 0; k < 8; k++)
+  //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
+  for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
   {
-    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
+    idx = base_idx + threadIdx.x;
+
+    __syncthreads();
+    if(idx < K && warp_id < (WARPS-1))
+    {
+      if(loaded_values == 0)
+      {
+        local_A[0] = A[idx];
+        local_A[1] = A[idx+blockDim.x-32];
+
+        #pragma unroll 32
+        for(int col = 0; col < 32; col++)
+        {
+          local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
+          local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx];
+        }
+
+        loaded_values = 1;
+      }
+      else
+      {
+        local_A[0] = local_A[1];
+        loaded_values--;
+
+        int absidx = (idx + col_offset)/blocksize;
+        half local_absmax = __ldg(&(absmax[absidx]));
+
+        #pragma unroll 64
+        for(int col = 0; col < 64; col+=2)
+        {
+          local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
+          local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+        }
+      }
+
+      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
+    }
+    else if(warp_id < (WARPS-1))
+    {
+      local_A[0] = T(0.0);
+      smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        local_B[col] = 0.0f;
+
+      #pragma unroll 32
+      for(int col = 0; col < 32; col++)
+        smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
+    }
+    ticktock = ticktock == 0 ? 1 : 0;
+
+    if(warp_id == (WARPS-1))
+      for(int k = 0; k < batch_size_warps; k++)
+      {
+        wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+        wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+      }
   }

-  if(threadIdx.x == 0)
-  {
-    #pragma unroll 8
-    for(int k = 0; k < 8; k++)
-      smem_C[k] = local_C[k];
-  }
-  else if(threadIdx.x >= 32)
-    // early return for unused warps
-    return;
-
-  __syncwarp();
+  __syncthreads();
+  if(warp_id != (WARPS-1)){ return; }
+  // only warp_id == (WARPS-1) from here
+  int warp_lane = threadIdx.x % 32;

-  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
-    out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
+  ticktock = ticktock == 0 ? 1 : 0;
+  for(int k = 0; k < batch_size_warps; k++)
+  {
+    wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
+    wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
+    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+  }
+
+  // 129 mu
+  if(warp_id == (WARPS-1))
+    wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
+
+  if(col_offset + warp_lane < M)
+    out[col_offset + warp_lane] = smem_A[warp_lane];
 }
//#define ROWS 2
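
The rewritten kgemm_4bit_inference mirrors the gemm_device pipeline: four producer warps stage A and dequantized B tiles into double-buffered shared memory (the ticktock flip), and the fifth warp consumes them as 8x32x16 wmma fragments. Dequantization works on packed nibbles, each byte of B holding two 4-bit codes with the high nibble first; note that at this draft stage the main loop scales by T(absidx) rather than by local_absmax. A hedged sketch of the unpacking step in isolation, where lookup16 is a hypothetical stand-in for dhDequantizeNF4's code book:

#include <cuda_fp16.h>

// Hypothetical stand-in for dhDequantizeNF4: map a 4-bit code to half through
// a 16-entry code book (the real NF4 table lives in bitsandbytes' kernels.cu).
__device__ half lookup16(const half *codebook, unsigned char code)
{
    return codebook[code & 0x0F];
}

// Unpack one byte of packed 4-bit weights into two scaled half values;
// the high nibble comes first, matching the shift/mask order in the diff.
__device__ void unpack_4bit_pair(unsigned char packed, half local_absmax,
                                 const half *codebook, half *out)
{
    out[0] = __hmul(lookup16(codebook, packed >> 4), local_absmax);    // high nibble
    out[1] = __hmul(lookup16(codebook, packed & 0x0F), local_absmax);  // low nibble
}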
@@ -3513,6 +3609,7 @@ template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * _
 template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
 template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 //template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);


@@ -703,17 +703,17 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
-  int num_blocks = (m+7)/8;
+  int num_blocks = (m+31)/32;

-  cout << num_blocks << endl;
-  cout << lda << endl;
-  cout << ldb << endl;
-  cout << ldc << endl;
+  //cout << num_blocks << endl;
+  //cout << lda << endl;
+  //cout << ldb << endl;
+  //cout << ldc << endl;

-  cout << m << endl;
-  cout << n << endl;
-  cout << k << endl;
-  kgemm_4bit_inference<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  //cout << m << endl;
+  //cout << n << endl;
+  //cout << k << endl;
+  kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
   //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
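
The launch arithmetic follows from the kernel: each block now covers 32 output columns (col_offset = blockIdx.x*32, hence (m+31)/32 blocks) and runs WARPS = 5 warps, i.e. 160 threads, four producer warps plus one tensor-core consumer. A small sketch of that arithmetic with illustrative names:

#include <cstdio>

// Sketch of the launch arithmetic implied by the kernel (names illustrative):
// one block per 32 output columns, five warps per block.
int main()
{
    int m = 4096;                                    // number of output columns (rows of B)
    const int cols_per_block = 32;                   // col_offset = blockIdx.x * 32
    const int threads = 5 * 32;                      // WARPS = 5 -> 160, 4 producers + 1 consumer
    int num_blocks = (m + cols_per_block - 1) / cols_per_block;   // same as (m+31)/32
    printf("num_blocks=%d threads=%d\n", num_blocks, threads);
    return 0;
}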


@@ -2358,20 +2358,19 @@ def test_normal_map_tree():
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_cutlass3_gemm(dtype):
-    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    debug = True
+    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4096, 5120, 6656, 8192]:
-    #for dim in [4096]:
+    for dim in [4096]:
     #for dim in [128+1]:
         errs = []
         relerrs = []
         max_err = 0
         max_relerr = 0
         for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
             #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim, dtype=dtype, device='cuda')
             B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
+            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
             #print('')
             #print(A)
@@ -2397,7 +2396,7 @@ def test_cutlass3_gemm(dtype):
             errs.append(err)
             relerrs.append(relerr)
-            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
             #    print('')
             #    print(i, err, relerr)
             #    print(A.flatten()[-6:])
@@ -2412,7 +2411,7 @@ def test_cutlass3_gemm(dtype):
         c = int(C1.numel()*0.0014*(dim/256))+1
-        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
+        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
         #print(c/math.sqrt(dim))
         print('')
         print(dim, sum(errs)/len(errs)/math.sqrt(dim))
@@ -2422,29 +2421,73 @@ def test_cutlass3_gemm(dtype):
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
 def test_gemm_4bit(dtype):
-    for i in range(1):
-        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #torch.random.manual_seed(17)
-        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4096, 5120, 6656, 8192]:
+    #for dim in [32]:
+    for dim in [4096]:
+        errs = []
+        relerrs = []
+        max_err = 0
+        max_relerr = 0
+        for i in range(1):
+            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
+            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
+            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)

-        #print('')
-        #print(A)
-        #print(B)
+            #print('')
+            #print(A)
+            #print(B.t())
+            #A[:, :-1] = 0
+            #B[:, :-1] = 0

-        qB, state = F.quantize_nf4(B)
-        F.dequantize_nf4(qB, state)
+            qB, state = F.quantize_nf4(B)
+            F.dequantize_nf4(qB, state)

-        C3 = torch.matmul(A, B.t())
-        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-        C1 = bnb.matmul_4bit(A, qB.t(), state)
-        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+            C1 = torch.matmul(A, B.t())
+            #C1 = bnb.matmul_4bit(A, qB.t(), state)
+            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+            #print(C1)
+            #print(C2)

-        print(C1.shape, C2.shape)
-        #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005)
+            # tensor cores are non-deterministic
+            # so we need to analyze errors around the mean
+            # to test our implementation
+            err = torch.abs(C1-C2)
+            mag = torch.abs(C1)+1e-8
+            relerr = err/mag
+            max_err = max(err.max(), max_err)
+            max_relerr = max(relerr.max(), max_relerr)
+            err = err.mean().item()
+            relerr = relerr.mean().item()
+
+            errs.append(err)
+            relerrs.append(relerr)
+
+            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
+                print('')
+                print(i, err, relerr)
+                print(A.flatten()[-6:])
+                print(B.flatten()[-6:])
+                out = A.flatten()[-6:]*B.flatten()[-6:]
+                print(out)
+                print(out[:-1].sum())
+                print('='*80)
+                print(C1.flatten()[-6:])
+                print(C2.flatten()[-6:])
+                #assert False, 'ERROR'
+
+        c = int(C1.numel()*0.0014*(dim/256))+1
+        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
+        #print(c/math.sqrt(dim))
+        print('')
+        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
+        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
+        print(dim, (max_err.item(), max_relerr.item()))
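
The test's comment states the rationale: tensor-core reductions are non-deterministic, so the check looks at mean absolute and mean relative error and tolerates a bounded number of outliers rather than demanding bitwise equality. A hedged C++ sketch of that acceptance logic (approx_close is illustrative, not the test helper itself):

#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of count-tolerant error analysis, mirroring the test above:
// mean abs/rel error plus an allowance for a few outlier elements
// (like assert_all_approx_close's count parameter).
bool approx_close(const std::vector<float>& c1, const std::vector<float>& c2,
                  float atol, float rtol, int allowed_outliers)
{
    int outliers = 0;
    double err_sum = 0.0, rel_sum = 0.0;
    for (size_t i = 0; i < c1.size(); i++)
    {
        float err = std::fabs(c1[i] - c2[i]);
        float rel = err / (std::fabs(c1[i]) + 1e-8f);   // same 1e-8 guard as the test
        err_sum += err; rel_sum += rel;
        if (err > atol && rel > rtol) outliers++;
    }
    printf("mean err=%f mean relerr=%f outliers=%d\n",
           err_sum / c1.size(), rel_sum / c1.size(), outliers);
    return outliers <= allowed_outliers;
}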
def test_pipeline_func():
a = torch.rand(2, 4).cuda()