forked from mrq/bitsandbytes-rocm
Fixed rowcol synchronization bug.
This commit is contained in:
parent
c771b3a75a
commit
7d2ecd30c0
|
@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
|
|||
|
||||
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
|
||||
__shared__ int smem_row_nnz_values[TILE_ROWS];
|
||||
//__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
|
||||
|
||||
half local_data[ITEMS_PER_THREAD];
|
||||
float local_data_fp32[ITEMS_PER_THREAD];
|
||||
|
@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
|
|||
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
|
||||
|
||||
// 3. compute row max (per block); store in smem to accumulate full global mem transation
|
||||
__syncthreads();
|
||||
|
||||
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_data_fp32[j] = local_data[j];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
|
||||
if(SPARSE_DECOMP)
|
||||
{
|
||||
|
|
Loading…
Reference in New Issue
Block a user