diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 0a845d4..dd1f6f2 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3528,29 +3528,26 @@ template __global__ void kgemm_4bit_inference_naive(in // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1 // 4 warps -> 4 loads per iter // 1x128 * 128x4 -> 1x4 outputs - typedef cub::WarpReduce WarpReduce; + //typedef cub::WarpReduce WarpReduce; + typedef cub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32]; const int warp_idx = threadIdx.x / 32; const int warp_lane = threadIdx.x % 32; const int row_B = (THREADS/32)*blockIdx.x + warp_idx; const int num_values_8bit = num_values_4bit/2; - T local_C = T(0); - - T lane_quant_value = nf4_data[warp_lane % 16]; + //T local_C = T(0.0f); + float local_C = 0.0f; unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit]; T local_A[num_values_4bit]; - __shared__ T quant_map[16*THREADS]; - __shared__ T quant_map2[16]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); - //for(int i = 0; i < 16; i++) - // quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i]; - //__syncthreads(); - - for(int i = 0; i < 16; i++) - quant_map2[i] = nf4_data[i]; + for(int i = threadIdx.x; i < 16; i++) + quant_map[i] = nf4_data[i]; + __syncthreads(); // A: [1, K] // B: [N, K] @@ -3559,7 +3556,7 @@ template __global__ void kgemm_4bit_inference_naive(in int inner_idx_halved = inner_idx/2; int offset_B = ldb*row_B; int absidx = ((2*offset_B)+inner_idx)/blocksize; - T local_absmax = __ldg(&(absmax[absidx])); + local_absmax = __ldg(&(absmax[absidx])); if(row_B < M) { @@ -3576,25 +3573,11 @@ template __global__ void kgemm_4bit_inference_naive(in } } - if(inner_idx+(num_values_4bit*32) < K) + #pragma unroll + for(int k = 0; k < num_values_4bit; k++) { - // full warp is running - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) - { - local_B[k*2] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] >> 4)*local_absmax; - local_B[k*2 + 1] = __shfl_sync(0xffffffff, lane_quant_value, local_B_4bit[k] & 0x0F)*local_absmax; - } - } - else - { - // part of the warp exited already - #pragma unroll - for(int k = 0; k < num_values_4bit; k++) - { - local_B[k*2] = quant_map2[(local_B_4bit[k] >> 4)]*local_absmax; - local_B[k*2 + 1] = quant_map2[(local_B_4bit[k] & 0x0F)]*local_absmax; - } + local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax; } if(inner_idx+num_values_4bit) @@ -3603,6 +3586,7 @@ template __global__ void kgemm_4bit_inference_naive(in reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 1]; reinterpret_cast(local_A)[2] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 2]; reinterpret_cast(local_A)[3] = reinterpret_cast(A)[inner_idx/(num_values_8bit/2) + 3]; + } else for(int k = 0; k < num_values_4bit; k++) @@ -3610,14 +3594,14 @@ template __global__ void kgemm_4bit_inference_naive(in #pragma unroll for(int k = 0; k < num_values_4bit; k++) - local_C += local_A[k]*local_B[k]; + local_C += (float)(local_A[k]*local_B[k]); } local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); if(row_B < M && warp_lane == 0) - out[row_B] = local_C; + out[row_B] = T(local_C); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 70d5515..552ccaa 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2420,7 +2420,7 @@ def test_cutlass3_gemm(dtype): def test_gemm_4bit(dtype): print('') #for dim in [64, 128, 256, 512, 1024, 2048, 4096]: - for dim in [4096]: + for dim in [4*1024]: errs = [] relerrs = [] max_err = 0 @@ -2485,10 +2485,10 @@ def test_gemm_4bit(dtype): #print(dim, sum(errs)/len(errs)/math.sqrt(dim)) #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) #print(dim, (max_err.item(), max_relerr.item())) - #print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) - #print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) - #assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 - #assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 + print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015) + print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015) + assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011 + assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15 @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed():