Fixed potential memory leak.

This commit is contained in:
Tim Dettmers 2023-07-10 13:57:44 -07:00
parent 490153b29f
commit 2221f4cee0

View File

@ -3561,7 +3561,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
if(row_B < M) if(row_B < M)
{ {
if((inner_idx_halved + num_values_8bit) < K) if((inner_idx_halved + num_values_8bit) < (K/2))
{ {
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
} }
@ -3569,15 +3569,21 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{ {
#pragma unroll #pragma unroll
for(int j = 0; j < (num_values_8bit); j++) for(int j = 0; j < (num_values_8bit); j++)
if((inner_idx_halved) + j < K) if((inner_idx_halved) + j < (K/2))
local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
else else
local_B_4bit[j] = 0b01110111; local_B_4bit[j] = 0b01110111;
} }
} }
else
{
#pragma unroll
for(int j = 0; j < (num_values_8bit); j++)
local_B_4bit[j] = 0b01110111;
}
#pragma unroll #pragma unroll
for(int k = 0; k < num_values_4bit; k++) for(int k = 0; k < num_values_8bit; k++)
{ {
#if __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ >= 800
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax; local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
@ -3625,7 +3631,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
local_C += ((float)local_A[k]*(float)local_B[k]); local_C += ((float)local_A[k]*(float)local_B[k]);
#endif #endif
} }
} }
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);