Fixed rowcol synchronization bug.

This commit is contained in:
Tim Dettmers 2022-07-22 15:21:37 -07:00
parent c771b3a75a
commit 7d2ecd30c0

View File

@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS]; __shared__ int smem_row_nnz_values[TILE_ROWS];
//__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half local_data[ITEMS_PER_THREAD]; half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD]; float local_data_fp32[ITEMS_PER_THREAD];
@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation // 3. compute row max (per block); store in smem to accumulate full global mem transation
__syncthreads();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD #pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++) for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j]; local_data_fp32[j] = local_data[j];
__syncthreads();
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
if(SPARSE_DECOMP) if(SPARSE_DECOMP)
{ {