Fixed potential memory leak.
This commit is contained in:
parent
490153b29f
commit
2221f4cee0
|
@ -3561,7 +3561,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
|
||||||
|
|
||||||
if(row_B < M)
|
if(row_B < M)
|
||||||
{
|
{
|
||||||
if((inner_idx_halved + num_values_8bit) < K)
|
if((inner_idx_halved + num_values_8bit) < (K/2))
|
||||||
{
|
{
|
||||||
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
|
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
|
||||||
}
|
}
|
||||||
|
@ -3569,15 +3569,21 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
|
||||||
{
|
{
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for(int j = 0; j < (num_values_8bit); j++)
|
for(int j = 0; j < (num_values_8bit); j++)
|
||||||
if((inner_idx_halved) + j < K)
|
if((inner_idx_halved) + j < (K/2))
|
||||||
local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
|
local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
|
||||||
else
|
else
|
||||||
local_B_4bit[j] = 0b01110111;
|
local_B_4bit[j] = 0b01110111;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
#pragma unroll
|
||||||
|
for(int j = 0; j < (num_values_8bit); j++)
|
||||||
|
local_B_4bit[j] = 0b01110111;
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for(int k = 0; k < num_values_4bit; k++)
|
for(int k = 0; k < num_values_8bit; k++)
|
||||||
{
|
{
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
|
local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
|
||||||
|
@ -3625,7 +3631,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
|
||||||
local_C += ((float)local_A[k]*(float)local_B[k]);
|
local_C += ((float)local_A[k]*(float)local_B[k]);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
|
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user