From 306f6b2362a8430bb407715ee5172a24893bad0f Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Mon, 10 Jul 2023 14:24:33 -0700
Subject: [PATCH] Fixed accidental deletion of limits in kernel.

---
 csrc/kernels.cu          | 10 ++++++++--
 tests/test_functional.py | 13 +++++++++----
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 407360e..902d759 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3595,7 +3595,7 @@ template __global__ void kgemm_4bit_inferenc
         #endif
       }

-      if(inner_idx+num_values_4bit)
+      if(inner_idx+num_values_4bit < K)
       {
         if(BITS==16)
         {
@@ -3619,11 +3619,17 @@ template __global__ void kgemm_4bit_inferenc
         }
         else
           for(int k = 0; k < num_values_4bit; k++)
-            local_A[k] = A[inner_idx +k];
+            if(inner_idx + k < K)
+              local_A[k] = A[inner_idx + k];
+            else
+              local_A[k] = T(0.0f);
+

       #pragma unroll
       for(int k = 0; k < num_values_4bit; k++)
       {
+        if((float)local_A[k] < -10.0f || (float)local_B[k] < -10.0f || local_C > 10.0f)
+          printf("%f %f = %f\n", (float)local_A[k], (float)local_B[k], local_C);
         #if __CUDA_ARCH__ >= 800
           local_C += (float)(local_A[k]*local_B[k]);
         #else
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 34552cb..e80eed3 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2378,8 +2378,8 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
     #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
     #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
     A = torch.randn(1, dim, dtype=dtype, device='cuda')
-    B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-    #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+    #B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+    B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

     #print('')
     #print(A)
@@ -2432,13 +2432,18 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
         #print(dim, (max_err.item(), max_relerr.item()))
         print(C1.flatten()[-20:])
         print(C2.flatten()[-20:])
-        print(C3.flatten()[-20:])
+        #print(C1.flatten())
+        #print(C2.flatten())
+        #print(C3.flatten()[-20:])
         print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
         print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
         if dtype == torch.float16:
             assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
             assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
-        else:
+        elif dtype == torch.float32:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-8
+        elif dtype == torch.bfloat16:
             assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
             assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
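
The kernels.cu hunks restore the "< K" bound on the per-chunk load, so the last, partial chunk of A is zero-padded instead of read out of bounds when K is not a multiple of num_values_4bit. Below is a minimal, standalone CUDA sketch of that tail-handling pattern; it is not the bitsandbytes kernel, and the names (chunked_dot, CHUNK) and launch configuration are invented stand-ins for kgemm_4bit_inference's internals.

// Minimal standalone sketch of the tail-handling pattern the patch restores:
// each thread loads a fixed-size chunk into registers, zero-filling any
// element whose index reaches K, so the unrolled accumulation never reads
// past the end of the input. Names (chunked_dot, CHUNK) are illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int CHUNK = 8;   // stand-in for num_values_4bit

__global__ void chunked_dot(const float* A, const float* B, float* out, int K)
{
    float local_A[CHUNK];
    float local_B[CHUNK];
    float local_C = 0.0f;

    // Grid-stride loop over K in CHUNK-sized pieces, one chunk per thread per step.
    for (int inner_idx = (blockIdx.x * blockDim.x + threadIdx.x) * CHUNK;
         inner_idx < K;
         inner_idx += gridDim.x * blockDim.x * CHUNK)
    {
        if (inner_idx + CHUNK < K)
        {
            // Fast path: the whole chunk is in bounds, load it unguarded.
            #pragma unroll
            for (int k = 0; k < CHUNK; k++)
            {
                local_A[k] = A[inner_idx + k];
                local_B[k] = B[inner_idx + k];
            }
        }
        else
        {
            // Tail path: guard every element and zero-pad past the end.
            // This is the bound the patch puts back.
            for (int k = 0; k < CHUNK; k++)
            {
                local_A[k] = (inner_idx + k < K) ? A[inner_idx + k] : 0.0f;
                local_B[k] = (inner_idx + k < K) ? B[inner_idx + k] : 0.0f;
            }
        }

        // Zero padding keeps the unrolled accumulation branch-free and exact.
        #pragma unroll
        for (int k = 0; k < CHUNK; k++)
            local_C += local_A[k] * local_B[k];
    }

    atomicAdd(out, local_C);
}

int main()
{
    const int K = 1000;   // deliberately not a multiple of CHUNK
    float *A, *B, *out;
    cudaMallocManaged(&A, K * sizeof(float));
    cudaMallocManaged(&B, K * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));

    double ref = 0.0;
    for (int i = 0; i < K; i++) { A[i] = 1.0f; B[i] = 0.5f; ref += (double)A[i] * B[i]; }
    *out = 0.0f;

    chunked_dot<<<4, 32>>>(A, B, out, K);
    cudaDeviceSynchronize();
    printf("gpu=%f ref=%f\n", *out, ref);   // both should print 500.000000

    cudaFree(A); cudaFree(B); cudaFree(out);
    return 0;
}

On the test side, the second tests/test_functional.py hunk splits the previous catch-all else branch: float32 now gets its own, much tighter error thresholds, while bfloat16 keeps the looser ones it had before.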