Added debugging functions.

This commit is contained in:
Tim Dettmers 2023-05-30 20:42:21 -07:00
parent b7f04e2a20
commit e54d2730fc
2 changed files with 14 additions and 3 deletions

View File

@ -3297,11 +3297,21 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
#endif
}
template <typename T> __device__ void printnonzero(T *A, int num_values)
{
for(int i = 0; i < num_values; i++)
if((float)A[i] != 0.0)
printf("%i %f\n", i, (float)A[i]);
}
template __device__ void printnonzero<float>(float *A, int num_values);
template __device__ void printnonzero<half>(half *A, int num_values);
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
@ -3469,9 +3479,10 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
if(warp_id == (WARPS-1))
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
printnonzero<T>(smem_A, 32);
if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
#endif
}
//#define ROWS 2

View File

@ -2414,7 +2414,7 @@ def test_gemm_4bit(dtype):
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
#for dim in [4096, 5120, 6656, 8192]:
#for dim in [32]:
for dim in [4096]:
for dim in [32]:
errs = []
relerrs = []
max_err = 0