Added debugging functions.
This commit is contained in:
parent
b7f04e2a20
commit
e54d2730fc
|
@ -3297,11 +3297,21 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
#endif
|
||||
}
|
||||
|
||||
|
||||
template <typename T> __device__ void printnonzero(T *A, int num_values)
|
||||
{
|
||||
for(int i = 0; i < num_values; i++)
|
||||
if((float)A[i] != 0.0)
|
||||
printf("%i %f\n", i, (float)A[i]);
|
||||
}
|
||||
|
||||
template __device__ void printnonzero<float>(float *A, int num_values);
|
||||
template __device__ void printnonzero<half>(half *A, int num_values);
|
||||
|
||||
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
|
||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{
|
||||
|
||||
#if __CUDA_ARCH__ >= 750
|
||||
using namespace nvcuda;
|
||||
int col_offset = blockIdx.x *32;
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
|
@ -3469,9 +3479,10 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
|
|||
if(warp_id == (WARPS-1))
|
||||
wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major);
|
||||
|
||||
printnonzero<T>(smem_A, 32);
|
||||
|
||||
if(col_offset + warp_lane < M)
|
||||
out[col_offset + warp_lane] = smem_A[warp_lane];
|
||||
#endif
|
||||
}
|
||||
|
||||
//#define ROWS 2
|
||||
|
|
|
@ -2414,7 +2414,7 @@ def test_gemm_4bit(dtype):
|
|||
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
#for dim in [4096, 5120, 6656, 8192]:
|
||||
#for dim in [32]:
|
||||
for dim in [4096]:
|
||||
for dim in [32]:
|
||||
errs = []
|
||||
relerrs = []
|
||||
max_err = 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user