4-bit draft; 128 vector load 240.

commit 264a948539 (parent 869b7e83b5)
@@ -1385,10 +1385,12 @@ def cutlass3_gemm(
    #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
    if state is None:
        Bshape = B.shape
        bout = Bshape[1]
    else:
        Bshape = state[1]
        bout = Bshape[0]
    if out is None:
        out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device)
        out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

    sA = A.shape
    sB = B.shape
@@ -1464,7 +1466,7 @@ def cutlass3_gemm(
    if state is not None:
        m = Bshape[0]
        k = Bshape[1]
        lda = Bshape[1]
        lda = Bshape[0]
        ldc = Bshape[0]
        ldb = (ldb+1)//2
    #print(m, n, k, lda, ldb, ldc)
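
The two hunks above are the Python side of the 4-bit path: when a quantization state is passed, B arrives packed, so the output width is taken from the stored original shape and the leading dimension of the packed buffer is half the element count (two 4-bit values per byte, hence the (ldb+1)//2). A minimal sketch of that bookkeeping (editor illustration with hypothetical names; it assumes state[1] holds B's pre-quantization shape, which the commit itself does not spell out):

# Hypothetical helper, not part of the commit.
def gemm_4bit_shapes(A_shape, B_shape, state=None):
    # Without a quantization state, B is a dense (K, N) matrix.
    if state is None:
        Bshape = B_shape
        bout = Bshape[1]
    # With a state, state[1] is assumed to hold B's original (N, K) shape,
    # so the number of output columns is its first dimension.
    else:
        Bshape = state[1]
        bout = Bshape[0]
    out_shape = (A_shape[0], bout)
    # Two 4-bit values are packed per byte, so a leading dimension measured in
    # elements must be halved (rounded up) to index the packed byte buffer.
    packed_ld = lambda ld: (ld + 1) // 2
    return out_shape, packed_ld(Bshape[1])
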

csrc/kernels.cu (307)
@@ -3044,22 +3044,15 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
#define WARPS 5
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{

typedef cub::WarpReduce<half> WarpReduce;
// Allocate WarpReduce shared memory for one warp
//__shared__ typename WarpReduce::TempStorage temp_storage;

//typedef cub::BlockReduce<T, THREADS> BlockReduce;
//// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
const int val_per_iter = blockDim.x-32;

T local_A[2];
T local_B[64];
T local_A[4];
T local_B[128];

const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);
@@ -3082,24 +3075,45 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
local_A[1] = A[idx+(1*val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)];

#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
}
loaded_values = 1;
loaded_values = 3;
}
else
{
local_A[0] = local_A[1];
loaded_values--;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+32];
if(loaded_values == 3)
{
local_A[0] = local_A[1];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)];
}
else if(loaded_values == 2)
{
local_A[0] = local_A[2];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)];
}
else
{
local_A[0] = local_A[3];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)];
}
loaded_values--;
}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
@@ -3139,26 +3153,46 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
local_A[1] = A[idx+(1*val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)];

#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+blockDim.x-32];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
}
loaded_values = 1;
loaded_values = 3;

}
else
{
local_A[0] = local_A[1];

if(loaded_values == 3)
{
local_A[0] = local_A[1];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)];
}
else if(loaded_values == 2)
{
local_A[0] = local_A[2];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)];
}
else
{
local_A[0] = local_A[3];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)];
}
loaded_values--;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+32];


}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
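
The two gemm_device hunks above widen the per-thread prefetch from two to four values of A (and from 64 to 128 values of B), with the loaded_values counter now counting down from 3 instead of 1. The value consumed in each outer iteration is unchanged; only the global loads are batched. A reference model of the rotation (editor sketch with hypothetical names, not part of the commit; A is assumed long enough for the accesses):

# Hypothetical reference model of the register rotation in gemm_device.
def consumed_a_values(A, tid, val_per_iter, iterations):
    regs, loaded_values, consumed = [], 0, []
    for t in range(iterations):
        idx = tid + t * val_per_iter
        if loaded_values == 0:
            # one wide fetch of four strided values, reused for four iterations
            regs = [A[idx + i * val_per_iter] for i in range(4)]
            loaded_values = 3
            consumed.append(regs[0])
        else:
            # reuse the prefetched value that matches the current idx
            consumed.append(regs[4 - loaded_values])
            loaded_values -= 1
    # equals [A[tid + t*val_per_iter] for t in range(iterations)]
    return consumed
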
@@ -3215,104 +3249,166 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{

typedef cub::BlockReduce<T, THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage reduce;
int col_offset = blockIdx.x *8;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;

T local_A[32];
unsigned char local_B_4bit[16];
T local_B[32];
T local_C[8];
T local_A[2];
T local_B[64];
unsigned char local_B_4bit[32];

__shared__ T smem_C[8];
const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);

if(threadIdx.x < 8)
smem_C[threadIdx.x] = T(0);
__syncthreads();
__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
//__shared__ T smem_C[8*32];

#pragma unroll 8
for(int k = 0; k < 8; k++)
local_C[k] = T(0);
wmma::fragment<wmma::matrix_a, 8, 32, 16, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, 8, 32, 16, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, 8, 32, 16, half> c_frag;
wmma::fill_fragment(c_frag, 0.0f);


for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32)
int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if(idx < K && warp_id < (WARPS-1))
{

// we load only 8 values per iteration from A, so we
// need to do 4 loads for every single load from B
// for B, we have packed values, so the 16 8-bit values
// turn into 32 4-bit values to 4x 4 loads turns into 4x 8 loads
vector_load<T, int4, 8>(local_A, A, idx, idx, K);
vector_load<T, int4, 8>(&(local_A[8]), A, idx+8, idx+8, K);
vector_load<T, int4, 8>(&(local_A[16]), A, idx+16, idx+16, K);
vector_load<T, int4, 8>(&(local_A[24]), A, idx+24, idx+24, K);

for(int col = 0; col < 8; col++)
if(loaded_values == 0)
{
if((col + col_offset) >= M){ break; }

int offset_B = (col_offset+col)*ldb;
// 0111 -> 0.0f in NF4
// since we have packed 8-bits, we need cat(0b0111, 0b0111) = 0b01110111
vector_load<unsigned char, int4, 16>(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111);

int absidx = (idx + offset_B)/blocksize;
half local_absmax = __ldg(&(absmax[absidx]));
//for(int k = 0; k < 16; k++)
//printf("%i %i ", local_B_4bit[k] >> 4, local_B_4bit[k] & 0x0F);
//printf("\n");

//vector_load<T, int4, 8>(local_A, A, idx, idx, K);

#pragma unroll 16
for(int k = 0; k < 16; k++)
{

//if(local_B_4bit[k ] != 0b01110111)
//printf("(%i %i %i) %i -> %f, %i -> %f\n", threadIdx.x , k, K, local_B_4bit[k ] >> 4, dDequantizeNF4(local_B_4bit[k ] >> 4, local_absmax),
//local_B_4bit[k ] & 0x0F, dDequantizeNF4(local_B_4bit[k ] & 0x0F, local_absmax));
//local_B[k*2] = d2DequantizeFP4(local_B_4bit[k] >> 4);//*local_absmax;
//local_B[k*2 + 1] = d2DequantizeFP4(local_B_4bit[k] & 0x0F);//*local_absmax;
local_B[k*2] = (half)(local_B_4bit[k] >> 4)*local_absmax;
local_B[k*2 + 1] = (half)(local_B_4bit[k] & 0x0F)*local_absmax;
//local_B[k*2] = (half)dDequantizeNF4(local_B_4bit[k ] >> 4);//*local_absmax;
//local_B[k*2 + 1] = (half)dDequantizeNF4(local_B_4bit[k ] & 0x0F);//*local_absmax;
}
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];

#pragma unroll 32
//for(int k = 0; k < 8; k++)
for(int k = 0; k < 32; k++)
for(int col = 0; col < 32; col++)
local_B_4bit[col] = B[(col_offset+col)*ldb+idx];

loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;

#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
local_C[col] += local_A[k]*local_B[k];
//if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
//if((float)local_B[k] != 0.0)
//printf("%i %i %i %i %f*%f\n", threadIdx.x, k, col, (float)local_A[k], (float)local_B[k]);
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
}
}
}

#pragma unroll 8
for(int k = 0; k < 8; k++)
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
{
idx = base_idx + threadIdx.x;

__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];

#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx];
}

loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;

int absidx = (idx + col_offset)/blocksize;
half local_absmax = __ldg(&(absmax[absidx]));

#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
}
}

smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;

#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;

if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}

if(threadIdx.x == 0)
__syncthreads();
if(warp_id != (WARPS-1)){ return; }
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;

ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
#pragma unroll 8
for(int k = 0; k < 8; k++)
smem_C[k] = local_C[k];
wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
else if(threadIdx.x >= 32)
// early return for unused warps
return;

__syncwarp();
// 129 mu
if(warp_id == (WARPS-1))
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);


if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
}

//#define ROWS 2
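
In the 4-bit inference kernel above, each byte of B packs two 4-bit codes; this draft experiments with several dequantization variants (raw nibble times absmax, dDequantizeNF4, dhDequantizeNF4), but the intent is to map each nibble through the NF4 code book and scale by the per-block absmax. A host-side reference sketch of that mapping (editor illustration, not part of the commit; the 1-D layout and an nf4_table matching the device-side dDequantizeNF4 mapping are assumptions):

import torch

# Hypothetical reference for dequantizing a packed NF4 buffer.
def dequantize_nf4_reference(packed, absmax, nf4_table, blocksize):
    hi = (packed >> 4).long()        # first value of each packed pair
    lo = (packed & 0x0F).long()      # second value of each packed pair
    # interleave the two code streams back into element order
    vals = torch.stack([nf4_table[hi], nf4_table[lo]], dim=-1).flatten()
    # one absmax scale per block of `blocksize` elements
    scales = absmax.repeat_interleave(blocksize)[: vals.numel()]
    return vals * scales
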
@@ -3513,6 +3609,7 @@ template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * _
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);

template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);


//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);

csrc/ops.cu (18)
@@ -703,17 +703,17 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{

int num_blocks = (m+7)/8;
int num_blocks = (m+31)/32;

cout << num_blocks << endl;
cout << lda << endl;
cout << ldb << endl;
cout << ldc << endl;
//cout << num_blocks << endl;
//cout << lda << endl;
//cout << ldb << endl;
//cout << ldc << endl;

cout << m << endl;
cout << n << endl;
cout << k << endl;
kgemm_4bit_inference<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//cout << m << endl;
//cout << n << endl;
//cout << k << endl;
kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

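
The host change above matches the kernel rewrite: one block now covers 32 outputs instead of 8 (col_offset = blockIdx.x * 32 in the kernel), and the block runs 160 threads, i.e. 5 warps as in #define WARPS 5, with four loader warps feeding shared memory and one tensor-core warp consuming it. A sketch of the implied launch shape (editor illustration, not part of the commit):

# Hypothetical helper mirroring the launch configuration in gemm_4bit_inference.
def launch_dims(m: int):
    num_blocks = (m + 31) // 32   # was (m + 7) // 8 with 8 outputs per block
    threads_per_block = 160       # was 128; 160 threads = 5 warps = WARPS
    return num_blocks, threads_per_block
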
@@ -2358,20 +2358,19 @@ def test_normal_map_tree():
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
    for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    debug = True
    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    #for dim in [4096, 5120, 6656, 8192]:
    #for dim in [4096]:
    for dim in [4096]:
    #for dim in [128+1]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for i in range(100):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
            A = torch.randn(1, dim, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

            #print('')
            #print(A)
@@ -2397,7 +2396,7 @@ def test_cutlass3_gemm(dtype):
            errs.append(err)
            relerrs.append(relerr)

            #if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
            # print('')
            # print(i, err, relerr)
            # print(A.flatten()[-6:])
@@ -2412,7 +2411,7 @@ def test_cutlass3_gemm(dtype):

        c = int(C1.numel()*0.0014*(dim/256))+1

        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=True)
        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
        #print(c/math.sqrt(dim))
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
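
The test above tolerates a bounded number of mismatches rather than exact equality because tensor-core accumulation order is non-deterministic. A sketch of the error-budget idea (editor illustration; assert_all_approx_close is defined elsewhere in the test suite, so the mismatch counting below is an assumption about its behavior):

import torch

# Hypothetical restatement of the tolerance used in the test.
def approx_close_budget(C1: torch.Tensor, dim: int) -> int:
    # allowed number of out-of-tolerance elements, growing with output size and dim
    return int(C1.numel() * 0.0014 * (dim / 256)) + 1

def count_mismatches(C1, C2, atol=1e-5, rtol=0.01):
    # elements that fail the elementwise closeness check
    return int((~torch.isclose(C1, C2, atol=atol, rtol=rtol)).sum().item())
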
@@ -2422,29 +2421,73 @@ def test_cutlass3_gemm(dtype):
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
    for i in range(1):
        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
        #torch.random.manual_seed(17)
        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
    #for dim in [4096, 5120, 6656, 8192]:
    #for dim in [32]:
    for dim in [4096]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
        for i in range(1):
            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
            A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)

            #print('')
            #print(A)
            #print(B)
            #print('')
            #print(A)
            #print(B.t())
            #A[:, :-1] = 0
            #B[:, :-1] = 0

            qB, state = F.quantize_nf4(B)
            F.dequantize_nf4(qB, state)
            qB, state = F.quantize_nf4(B)
            F.dequantize_nf4(qB, state)

            C3 = torch.matmul(A, B.t())
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
            C1 = bnb.matmul_4bit(A, qB.t(), state)
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)

            C1 = torch.matmul(A, B.t())
            #C1 = bnb.matmul_4bit(A, qB.t(), state)
            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
            #print(C1)
            #print(C2)
            print(C1.shape, C2.shape)

            #torch.testing.assert_close(C1, C2, atol=1e-5, rtol=0.005)
            # tensor cores are non-deterministic
            # so we need to analyze errors around the mean
            # to test our implementation
            err = torch.abs(C1-C2)
            mag = torch.abs(C1)+1e-8
            relerr = err/mag
            max_err = max(err.max(), max_err)
            max_relerr = max(relerr.max(), max_relerr)
            err = err.mean().item()
            relerr = relerr.mean().item()

            errs.append(err)
            relerrs.append(relerr)

            if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
                print('')
                print(i, err, relerr)
                print(A.flatten()[-6:])
                print(B.flatten()[-6:])
                out = A.flatten()[-6:]*B.flatten()[-6:]
                print(out)
                print(out[:-1].sum())
                print('='*80)
                print(C1.flatten()[-6:])
                print(C2.flatten()[-6:])
                #assert False, 'ERROR'

        c = int(C1.numel()*0.0014*(dim/256))+1

        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
        #print(c/math.sqrt(dim))
        print('')
        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        print(dim, (max_err.item(), max_relerr.item()))

def test_pipeline_func():
    a = torch.rand(2, 4).cuda()