From e54d2730fc033489be1ee61dab5ac5e22f798527 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 30 May 2023 20:42:21 -0700 Subject: [PATCH] Added debugging functions. --- csrc/kernels.cu | 15 +++++++++++++-- tests/test_functional.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 7a752cb..ea0be06 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3297,11 +3297,21 @@ template __global__ void gemm_device(int M, #endif } + +template __device__ void printnonzero(T *A, int num_values) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%i %f\n", i, (float)A[i]); +} + +template __device__ void printnonzero(float *A, int num_values); +template __device__ void printnonzero(half *A, int num_values); + __device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { -#if __CUDA_ARCH__ >= 750 using namespace nvcuda; int col_offset = blockIdx.x *32; const int warp_id = threadIdx.x / 32; @@ -3469,9 +3479,10 @@ template __global__ void kgemm_4bit_inference(int M, i if(warp_id == (WARPS-1)) wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + printnonzero(smem_A, 32); + if(col_offset + warp_lane < M) out[col_offset + warp_lane] = smem_A[warp_lane]; -#endif } //#define ROWS 2 diff --git a/tests/test_functional.py b/tests/test_functional.py index 29b82e6..54ceed5 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2414,7 +2414,7 @@ def test_gemm_4bit(dtype): #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: #for dim in [4096, 5120, 6656, 8192]: #for dim in [32]: - for dim in [4096]: + for dim in [32]: errs = [] relerrs = [] max_err = 0