diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 216d436..34e552b 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3519,7 +3519,7 @@ template __global__ void kgemm_4bit_inference(int M, i
     out[col_offset + warp_lane] = smem_C[warp_lane];
 }
 
-#define num_values_4bit 16
+#define num_values_4bit 32
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
@@ -3529,72 +3529,68 @@ template __global__ void kgemm_4bit_inference_naive(in
   // 4 warps -> 4 loads per iter
   // 1x128 * 128x4 -> 1x4 outputs
   typedef cub::WarpReduce<T> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[4];
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
 
   const int warp_idx = threadIdx.x / 32;
   const int warp_lane = threadIdx.x % 32;
-  const int row_B = 4*blockIdx.x + warp_idx;
+  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
+  const int num_values_8bit = num_values_4bit/2;
   T local_C = T(0);
 
-  T quant_map[16];
-  #pragma unroll 16
-  for(int i = 0; i < 16; i++)
-    quant_map[i] = nf4_data[i];
-
-  unsigned char local_B_4bit[num_values_4bit/2];
+  unsigned char local_B_4bit[num_values_8bit];
   T local_B[num_values_4bit];
+  T local_A[num_values_4bit];
+  __shared__ half quant_map[16*THREADS];
 
-  // need to increase occupancy by splitting the rows, but can be done later
+  for(int i = 0; i < 16; i++)
+    quant_map[threadIdx.x + (i*blockDim.x)] = nf4_data[i];
+  __syncthreads();
 
   // A: [1, K]
   // B: [N, K]
   for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
   {
-    int offset_B = ldb*row_B + (inner_idx/2);
-    int absidx = (2*offset_B)/blocksize;
+    int inner_idx_halved = inner_idx/2;
+    int offset_B = ldb*row_B;
+    int absidx = ((2*offset_B)+inner_idx)/blocksize;
     T local_absmax = __ldg(&(absmax[absidx]));
-    //printf("%f %i %i %i %i %i %i\n", (float)local_absmax, absidx, lda*row_B, K, ldb, row_B, offset_B);
 
-    #pragma unroll
-    for(int k = 0; k < num_values_4bit/2; k++)
+    if(row_B < M)
     {
-      if((inner_idx/2) < K && row_B < M)
-        local_B_4bit[k] = B[offset_B + k];
+      if((inner_idx_halved + num_values_8bit) < K)
+      {
+        reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
+      }
       else
-        local_B_4bit[k] = 0b01110111;
+      {
+        #pragma unroll
+        for(int j = 0; j < (num_values_8bit); j++)
+          if((inner_idx/2) + j < K)
+            local_B_4bit[j] = 0b01110111;
+      }
     }
-
-    //if(row_B < M)
-    //{
-    //  if((inner_idx/num_values_4bit) < K)
-    //    reinterpret_cast<int4*>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[offset_B/(num_values_4bit/2)];
-    //  else
-    //  {
-    //    for(int k = 0; k < num_values_4bit/2; k++)
-    //    {
-    //      if((inner_idx/2) < K && row_B < M)
-    //        local_B_4bit[k] = B[offset_B + k];
-    //      else
-    //        local_B_4bit[k] = 0b01110111;
-    //    }
-    //  }
-    //}
-
-
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
     {
-      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
-      local_B[k*2+ 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+      local_B[k*2] = quant_map[(local_B_4bit[k] >> 4)*THREADS+threadIdx.x]*local_absmax;
+      local_B[k*2+ 1] = quant_map[(local_B_4bit[k] & 0x0F)*THREADS+threadIdx.x]*local_absmax;
     }
 
-    //printnonzero(local_B, 4, "B values: ");
+    if(inner_idx+num_values_4bit)
+    {
+      reinterpret_cast<int4*>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 0];
+      reinterpret_cast<int4*>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 1];
+      reinterpret_cast<int4*>(local_A)[2] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 2];
+      reinterpret_cast<int4*>(local_A)[3] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_8bit/2) + 3];
+    }
+    else
+      for(int k = 0; k < num_values_4bit; k++)
+        local_A[k] = A[inner_idx + k];
 
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
-      local_C += A[inner_idx + k]*local_B[k];
+      local_C += local_A[k]*local_B[k];
 
   }
@@ -3773,6 +3769,7 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 
+template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index ed242c9..c30e979 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -733,6 +733,7 @@ template void gemm_4bit_inference_naive(int m, int n, int k, T * A,
 {
 
   int num_blocks = (m+3)/4;
+  //int num_blocks = m;
 
   cout << num_blocks << endl;
   //cout << lda << endl;
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 752dd1d..598b995 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2415,21 +2415,21 @@ def test_gemm_4bit(dtype):
     #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4096, 5120, 6656, 8192]:
     #for dim in [32]:
-    for dim in [4096]:
+    for dim in [2*4096]:
     #for dim in [5120]:
     #for dim in [6656]:
-    #for dim in [128]:
+    #for dim in [4]:
        errs = []
        relerrs = []
        max_err = 0
        max_relerr = 0
-       for i in range(1):
+       for i in range(100):
           #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
           #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
           #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
           #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-          A = torch.randn(1, dim+2, dtype=dtype, device='cuda')
-          B = torch.randn(4*dim, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+          A = torch.randn(1, dim, dtype=dtype, device='cuda')
+          B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
           #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
           #print('')
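Illustrative note on the dequantization performed in the kernel's inner loop: each packed byte yields two values by looking up the high and the low nibble in the 16-entry NF4 table and scaling by the block's absmax; the padding byte 0b01110111 decodes to zero because NF4 index 7 maps to 0.0. A minimal host-side sketch of that per-byte decode follows; nf4_codebook and absmax_block are hypothetical stand-ins (codebook values rounded), not identifiers from this patch.

    // Hypothetical host-side sketch of the per-byte NF4 decode done in kgemm_4bit_inference_naive.
    #include <cstdio>

    int main()
    {
      // Approximate NF4 codebook (rounded to 4 decimals); index 7 is exactly 0.0f.
      float nf4_codebook[16] = {-1.0f, -0.6962f, -0.5251f, -0.3949f, -0.2844f, -0.1848f, -0.0911f, 0.0f,
                                 0.0796f, 0.1609f, 0.2461f, 0.3379f, 0.4407f, 0.5626f, 0.7230f, 1.0f};
      float absmax_block = 2.5f;          // per-block scale, stand-in value
      unsigned char packed = 0b01110111;  // the padding byte used in the kernel: both nibbles are 7

      float first  = nf4_codebook[packed >> 4]   * absmax_block;  // high nibble -> first value
      float second = nf4_codebook[packed & 0x0F] * absmax_block;  // low nibble  -> second value
      printf("%f %f\n", first, second);  // prints 0.000000 0.000000 for the padding byte
      return 0;
    }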