forked from mrq/bitsandbytes-rocm

Merge branch 'patch_merge' into extract_outliers

This commit is contained in commit 5737f2b027
CHANGELOG.md (14 lines changed)

@@ -53,3 +53,17 @@ Bug fixes:
 
 Docs:
 - Added instructions on how to solve "__fatbinwrap_" errors.
+
+### 0.30.0
+
+#### 8-bit Inference Update
+
+Features:
+- Added 8-bit matrix multiplication from cuBLAS and cuBLASLt, as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt).
+- Added 8-bit Linear layers with 8-bit Params that perform memory-efficient inference, with an option for 8-bit mixed-precision matrix decomposition for inference without performance degradation.
+- Added quantization methods for "fake" quantization, as well as optimized kernels for vector-wise quantization and equalization, and optimized cuBLASLt transformations.
+- CPU-only build now available (thank you, @mryab).
+
+Deprecated:
+- Pre-compiled releases for CUDA 9.2, 10.0, and 10.2 are no longer available.
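For context, here is a minimal usage sketch of the 8-bit inference feature described in the changelog entry above. The class and argument names (bnb.nn.Linear8bitLt, has_fp16_weights, threshold) are assumptions taken from the public bitsandbytes Python API, not from this diff:

# Hedged sketch: 8-bit linear layer with mixed-precision decomposition.
# API names are assumptions based on the public bitsandbytes package.
import torch
import bitsandbytes as bnb

# threshold=6.0 enables mixed-precision decomposition: activation outliers
# above the threshold are handled in fp16, everything else in int8.
layer = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0)
layer = layer.cuda()  # weights are quantized to 8 bits when moved to the GPU

x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
with torch.no_grad():
    y = layer(x)  # memory-efficient 8-bit inference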
Makefile (37 lines changed)

@@ -16,7 +16,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
 FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
 
 INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
-LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
+LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
 
 # NVIDIA NVCC compilation flags
 COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler

@@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
 COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
 COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
 COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
 
 # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
 CC_CUDA92 := -gencode arch=compute_30,code=sm_30

@@ -43,31 +42,49 @@ CC_CUDA11x := -gencode arch=compute_75,code=sm_75
 CC_CUDA11x += -gencode arch=compute_80,code=sm_80
 CC_CUDA11x += -gencode arch=compute_86,code=sm_86
 
+CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
+CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
+
+CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
+CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+
 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
 	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
 cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
-cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
+	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
-cuda110: $(BUILD_DIR) env
-	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+cuda110_nomatmul: $(BUILD_DIR) env
+	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
-cuda11x: $(BUILD_DIR) env
-	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+cuda11x_nomatmul: $(BUILD_DIR) env
+	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
+cuda110: $(BUILD_DIR) env
+	$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+cuda11x: $(BUILD_DIR) env
+	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
 cpuonly: $(BUILD_DIR) env
 	$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so
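Taken together, the Makefile changes split the build in two: the renamed cuda10x_nomatmul/cuda110_nomatmul/cuda11x_nomatmul targets (and the default all target) compile with -D NO_CUBLASLT, while the new cuda110 and cuda11x targets build the full cuBLASLt matmul kernels for sm_75/sm_80 (and sm_86 on 11.1+) via the new CC_cublasLt110/CC_cublasLt111 flags. Linking now pulls in -lcublasLt in place of -lcuda.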
bitsandbytes/functional.py

@@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
         ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
     return out
 
-def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     shapeA = SA[0]
     shapeB = SB[0]
     dimsA = len(shapeA)

@@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
     elif dimsA == 3 and out is None:
         out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
 
-    if row_scale is not None: assert row_scale.numel() == out.shape[0]
     assert dimsB != 3, 'len(B.shape)==3 not supported'
     assert A.device.type == 'cuda'
     assert B.device.type == 'cuda'

@@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
     ptrA = get_ptr(A)
     ptrB = get_ptr(B)
     ptrC = get_ptr(out)
-    ptrRowScale = get_ptr(row_scale)
 
     k = shapeA[-1]
     lda = ct.c_int32(m*32)

@@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
     k = ct.c_int32(k)
 
     has_error = 0
+    ptrRowScale = get_ptr(None)
     if formatB == 'col_turing':
         if dtype == torch.int32:
             has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-        elif row_scale is None:
-            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
         else:
-            has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
     elif formatB == 'col_ampere':
         if dtype == torch.int32:
             has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
-        elif row_scale is None:
-            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
         else:
-            has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
 
     if has_error == 1:
         raise Exception('cublasLt ran into an error!')
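After this change, row scaling is gone from the Python entry point and callers pass only the transformed tensors plus their layout states. A hedged sketch of the call path follows; it assumes igemmlt allocates and returns (out, Sout) when no output buffer is passed, as the get_transform_buffer branch above suggests, and mirrors the F.transform usage from quicktest.py later in this commit:

# Hedged sketch of the int8 matmul path; the igemmlt return value is an
# assumption inferred from the buffer-allocation branch in the hunk above.
import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, size=(32, 64), device='cuda').to(torch.int8)
B = torch.randint(-128, 127, size=(48, 64), device='cuda').to(torch.int8)

A2, SA = F.transform(A, 'col32')  # activations into COL32 layout
B2, SB = F.transform(B, 'colx')   # weights into the GPU-specific layout
out, Sout = F.igemmlt(A2, B2, SA, SB)       # int32 accumulation by default
C, _ = F.transform(out, 'row', state=Sout)  # back to row-major
print(torch.allclose(torch.matmul(A.float(), B.t().float()), C.float()))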
csrc/kernels.cu

@@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
 
   __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
   __shared__ int smem_row_nnz_values[TILE_ROWS];
-  //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
 
   half local_data[ITEMS_PER_THREAD];
   float local_data_fp32[ITEMS_PER_THREAD];

@@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
       local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
 
     // 3. compute row max (per block); store in smem to accumulate full global mem transaction
-    __syncthreads();
 
     // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
     #pragma unroll ITEMS_PER_THREAD
     for(int j = 0; j < ITEMS_PER_THREAD; j++)
      local_data_fp32[j] = local_data[j];
 
+    __syncthreads();
+
     row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
     if(SPARSE_DECOMP)
     {

@@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
   __shared__ char smem_data[32*33*ITEMS_PER_THREAD];
   char local_data[ITEMS_PER_THREAD];
   typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
-  __shared__ typename BlockExchange::TempStorage temp_storage;
 
   // we load row after row from the base_position
   // Load data row by row

@@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
 template <typename T, int SPMM_ITEMS, int BITS>
-__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
 
   // 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block

@@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   {
     for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
      if((idx_col_B+i-local_idx_col_B_offset) < colsB)
-        smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+        smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
 
     __syncthreads();
   }
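A note on the kspmm_coo_very_sparse_naive signature change: qualifying dequant_stats as float * __restrict__ const promises the compiler that the buffer is read-only and unaliased for the lifetime of the kernel, which lets plain loads be served through the read-only data cache (the same path __ldg targets). The explicit __ldg() intrinsic in the shared-memory fill then becomes redundant, which is presumably why it was dropped here.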
csrc/kernels.cuh

@@ -107,7 +107,7 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
 
 
-template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
 
 template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
     int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
csrc/ops.cu (22 lines changed)

@@ -247,6 +247,8 @@ int roundoff(int v, int d) {
 }
 
 
+#ifdef NO_CUBLASLT
+#else
 template<int ORDER> cublasLtOrder_t get_order()
 {
 	switch(ORDER)

@@ -266,7 +268,11 @@ template<int ORDER> cublasLtOrder_t get_order()
 		case COL_AMPERE:
 			return CUBLASLT_ORDER_COL32_2R_4R4;
 			break;
+		default:
+			break;
 	}
 
+	return CUBLASLT_ORDER_ROW;
 }
 
 template cublasLtOrder_t get_order<ROW>();

@@ -274,6 +280,7 @@ template cublasLtOrder_t get_order<COL>();
 template cublasLtOrder_t get_order<COL32>();
 template cublasLtOrder_t get_order<COL_TURING>();
 template cublasLtOrder_t get_order<COL_AMPERE>();
+#endif
 
 
 template<int ORDER> int get_leading_dim(int dim1, int dim2)

@@ -297,6 +304,9 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
 			// 32*32 tiles
 			return 32*roundoff(dim1, 32);
 			break;
+		default:
+			return 0;
+			break;
 	}
 }
 

@@ -306,7 +316,8 @@ template int get_leading_dim<COL32>(int dim1, int dim2);
 
 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
 {
+#ifdef NO_CUBLASLT
+#else
 	cublasLtOrder_t orderA = get_order<SRC>();
 	cublasLtOrder_t orderOut = get_order<TARGET>();
 	int ldA = get_leading_dim<SRC>(dim1, dim2);

@@ -345,6 +356,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
 	if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
 	if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
 	if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
+#endif
 }
 
 template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);

@@ -358,6 +370,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
 
 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
+#ifdef NO_CUBLASLT
+	return 0;
+#else
 	int has_error = 0;
 	cublasLtMatmulDesc_t matmulDesc = NULL;
 	cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;

@@ -412,6 +427,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
 		printf("error detected");
 
 	return has_error;
+#endif
 }
 
 int fill_up_to_nearest_multiple(int value, int multiple)

@@ -523,6 +539,9 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
 void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
 {
 
+#ifdef NO_CUBLASLT
+#else
+
 	cusparseSpMatDescr_t descA;
 	cusparseDnMatDescr_t descB, descC;
 

@@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
 	CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
 	CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
 	CUDA_CHECK_RETURN( cudaFree(dBuffer) );
+#endif
 }
 
 template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
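The NO_CUBLASLT guards added throughout this file are the other half of the *_nomatmul Makefile targets: when compiled with -D NO_CUBLASLT, the cuBLASLt order/transform helpers are preprocessed out and igemmlt simply returns 0, so the shared library still builds and loads against CUDA toolkits and GPUs without cuBLASLt support; only the 8-bit matmul path is unavailable.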
csrc/pythonInterface.c

@@ -82,7 +82,6 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un
 
 void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
 void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
-#endif
 
 #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
 void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \

@@ -132,10 +131,11 @@ void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_r
 
 void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
 { spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+#endif
 
 extern "C"
 {
 #if BUILD_CUDA
 	void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
 	void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
 	void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }

@@ -231,7 +231,7 @@ extern "C"
 		{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
 
 	int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
-		{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+		{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
 
 	#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
 	void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
cuda_install.sh (new file, 77 lines)

@@ -0,0 +1,77 @@
+URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux
+URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux
+URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run
+URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
+URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run
+URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
+URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
+URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run
+URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run
+URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run
+URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run
+URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
+
+
+CUDA_VERSION=$1
+BASE_PATH=$2
+
+if [[ -n "$CUDA_VERSION" ]]; then
+	if [[ "$CUDA_VERSION" -eq "92" ]]; then
+		URL=$URL92
+		FOLDER=cuda-9.2
+	elif [[ "$CUDA_VERSION" -eq "100" ]]; then
+		URL=$URL100
+		FOLDER=cuda-10.0
+	elif [[ "$CUDA_VERSION" -eq "101" ]]; then
+		URL=$URL101
+		FOLDER=cuda-10.1
+	elif [[ "$CUDA_VERSION" -eq "102" ]]; then
+		URL=$URL102
+		FOLDER=cuda-10.2
+	elif [[ "$CUDA_VERSION" -eq "110" ]]; then
+		URL=$URL110
+		FOLDER=cuda-11.0
+	elif [[ "$CUDA_VERSION" -eq "111" ]]; then
+		URL=$URL111
+		FOLDER=cuda-11.1
+	elif [[ "$CUDA_VERSION" -eq "112" ]]; then
+		URL=$URL112
+		FOLDER=cuda-11.2
+	elif [[ "$CUDA_VERSION" -eq "113" ]]; then
+		URL=$URL113
+		FOLDER=cuda-11.3
+	elif [[ "$CUDA_VERSION" -eq "114" ]]; then
+		URL=$URL114
+		FOLDER=cuda-11.4
+	elif [[ "$CUDA_VERSION" -eq "115" ]]; then
+		URL=$URL115
+		FOLDER=cuda-11.5
+	elif [[ "$CUDA_VERSION" -eq "116" ]]; then
+		URL=$URL116
+		FOLDER=cuda-11.6
+	elif [[ "$CUDA_VERSION" -eq "117" ]]; then
+		URL=$URL117
+		FOLDER=cuda-11.7
+	else
+		echo "argument error: Unsupported CUDA version: $CUDA_VERSION. Choose among: {92, 100, 101, 102, 110, 111, 112, 113, 114, 115, 116, 117}"
+	fi
+else
+	echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+fi
+
+FILE=$(basename $URL)
+
+if [[ -n "$CUDA_VERSION" ]]; then
+	echo $URL
+	echo $FILE
+	wget $URL
+	bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=$BASE_PATH/lib --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
+	echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
+	echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
+	source ~/.bashrc
+else
+	echo ""
+fi
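As a usage sketch, given the argument layout above, `bash cuda_install.sh 117 ~/local` would download the CUDA 11.7 runfile, install the toolkit silently into ~/local/cuda-11.7, and append the matching LD_LIBRARY_PATH and PATH exports to ~/.bashrc. The version argument is the concatenated major/minor number from the URL table (92 through 117).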
@@ -1,86 +1,261 @@
 #!/bin/bash
+BASE_PATH=$1
 
 module unload cuda
 module unload gcc
 
-rm -rf dist build
-make clean
-make cleaneggs
-module load cuda/9.2
-module load gcc/7.3.0
-CUDA_HOME=/public/apps/cuda/9.2
-make
-CUDA_VERSION=92 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+#rm -rf dist build
+#make clean
+#make cleaneggs
+#export CUDA_HOME=
+#make cpuonly
+#
+#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+#	# Control will enter here if the .so file doesn't exist.
+#	echo "Compilation unsuccessful!" 1>&2
+#	exit 64
+#fi
+#CUDA_VERSION=cpu python -m build
+#python -m twine upload dist/* --verbose --repository testpypi
 
 rm -rf dist build
 make clean
 make cleaneggs
-module load cuda/10.0
-CUDA_HOME=/public/apps/cuda/10.0
-make cuda10x
-CUDA_VERSION=100 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
-module unload gcc
-module load gcc/8.4
-
-rm -rf dist build
-make clean
-make cleaneggs
-module load cuda/10.1
-CUDA_HOME=/public/apps/cuda/10.1
-make cuda10x
-CUDA_VERSION=101 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
-
-rm -rf dist build
-make clean
-make cleaneggs
-module load cuda/10.2
-CUDA_HOME=/public/apps/cuda/10.2/
-make cuda10x
-CUDA_VERSION=102 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
-
-
-rm -rf dist build
-make clean
-make cleaneggs
-module load cuda/11.0
-CUDA_HOME=/public/apps/cuda/11.0
+export CUDA_HOME=$BASE_PATH/cuda-11.0
 make cuda110
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
 CUDA_VERSION=110 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+python -m twine upload dist/* --verbose --repository testpypi
 
 rm -rf dist build
 make clean
 make cleaneggs
-module load cuda/11.1
-CUDA_HOME=/public/apps/cuda/11.1
+export CUDA_HOME=$BASE_PATH/cuda-11.1
 make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
 CUDA_VERSION=111 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+python -m twine upload dist/* --verbose --repository testpypi
 
 rm -rf dist build
 make clean
 make cleaneggs
-module load cuda/11.2
-CUDA_HOME=/public/apps/cuda/11.2
+export CUDA_HOME=$BASE_PATH/cuda-11.2
 make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
 CUDA_VERSION=112 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+python -m twine upload dist/* --verbose --repository testpypi
 
 rm -rf dist build
 make clean
 make cleaneggs
-CUDA_HOME=/private/home/timdettmers/git/autoswap/local/cuda-11.3 make cuda11x
+export CUDA_HOME=$BASE_PATH/cuda-11.3
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
 CUDA_VERSION=113 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.4
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=114 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.5
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=115 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+#rm -rf dist build
+#make clean
+#make cleaneggs
+#export CUDA_HOME=$BASE_PATH/cuda-11.6
+#
+#make cuda11x
+#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+#	# Control will enter here if the .so file doesn't exist.
+#	echo "Compilation unsuccessful!" 1>&2
+#	exit 64
+#fi
+#CUDA_VERSION=116 python -m build
+#python -m twine upload dist/* --verbose --repository testpypi
+#
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.7
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=117 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-10.2
+make cuda10x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=102-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.0
+make cuda110_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=110-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.1
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=111-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.2
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=112-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.3
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=113-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.4
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=114-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.5
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=115-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.6
+
+make cuda11x_nomatmul
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=116-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.7
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+	# Control will enter here if the .so file doesn't exist.
+	echo "Compilation unsuccessful!" 1>&2
+	exit 64
+fi
+CUDA_VERSION=117-nomatmul python -m build
 python -m twine upload dist/* --verbose
-module unload cuda
+python -m twine upload dist/* --verbose --repository testpypi
quicktest.py (new file, 90 lines)

@@ -0,0 +1,90 @@
+import torch
+import bitsandbytes as bnb
+import bitsandbytes.functional as F
+
+from itertools import product
+
+def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
+    k = 25
+    for i in range(k):
+        if dims == 2:
+            A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+        elif dims == 3:
+            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+        B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+        C1 = torch.matmul(A.float(), B.t().float())
+
+        A2, SA = F.transform(A, 'col32')
+        B2, SB = F.transform(B, 'colx')
+        if dims == 2:
+            C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+        else:
+            C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+        F.igemmlt(A2, B2, C2, SA, SB, SC)
+        C3, S = F.transform(C2, 'row', state=SC)
+        #torch.testing.assert_allclose(C1, C3.float())
+        #print(C1)
+        #print(C2)
+        #print(C3)
+        allclose = torch.allclose(C1, C3.float())
+        if allclose:
+            print(C1)
+            print(C2)
+            print(C3)
+
+        ## transposed
+        #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+        #if dims == 2:
+        #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+        #    C1 = torch.matmul(A.float(), B.float().t())
+        #elif dims == 3:
+        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+        #    C1 = torch.matmul(B.float(), A.t().float())
+        #    C1 = C1.permute([2, 0, 1])
+
+        #A2, SA = F.transform(A, 'col32')
+        #B2, SB = F.transform(B, 'colx')
+        #if dims == 2:
+        #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+        #else:
+        #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
+        #    state = (C2.shape, 'row', A.shape[0])
+        #    C2, SC = F.transform(C2, 'col32', state=state)
+        #F.igemmlt(A2, B2, C2, SA, SB, SC)
+        #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
+        #torch.testing.assert_allclose(C1, C3.float())
+
+        ## weight update
+        #if dims == 3:
+        #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
+        #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
+
+        #    A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
+        #    B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
+        #    C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
+        #    C2, SC = F.transform(C2, 'col32')
+        #    F.igemmlt(B2, A2, C2, SB, SA, SC)
+        #    C3, S = F.transform(C2, 'row', state=SC)
+        #    torch.testing.assert_allclose(C1, C3.float())
+
+    return allclose  # report whether the comparison matched, so the sweep below can print it
+
+
+dims = (2, 3)
+ldb = [0]
+
+n = 2
+dim1 = torch.randint(1,256, size=(n,)).tolist()
+dim2 = torch.randint(32,512, size=(n,)).tolist()
+dim3 = torch.randint(32,1024, size=(n,)).tolist()
+dim4 = torch.randint(32,1024, size=(n,)).tolist()
+values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
+
+for ldb in range(32, 4096, 32):
+#for ldb in [None]:
+    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
+    if val:
+        print(val, ldb)
+    else:
+        print('nope', ldb)
+#for val in values:
+    #test_igemmlt(*val)
setup.py (7 lines changed)

@@ -11,13 +11,14 @@ def read(fname):
 
 
 version = os.getenv("CUDA_VERSION", "cpu")
+prefix = '' if version == 'cpu' else 'cuda'
 
 setup(
-    name="bitsandbytes",
-    version=f"0.26.0+{version}",
+    name=f"bitsandbytes-{prefix}{version}",
+    version=f"0.30.0",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
-    description="8-bit optimizers and quantization routines.",
+    description="8-bit optimizers and matrix multiplication routines.",
     license="MIT",
     keywords="gpu optimizers optimization 8-bit quantization compression",
     url="http://packages.python.org/bitsandbytes",
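With this change every wheel encodes its CUDA target in the package name via the f-string above, so installs would look like `pip install bitsandbytes-cuda110` (or `bitsandbytes-cpu` for the CPU-only build), all published at version 0.30.0 instead of the previous 0.26.0+{version} local-version scheme.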
tests/test_functional.py

@@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
 values = list(zip(dim1, dim4, inner))
 names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
 def test_igemmlt_row_scale(dim1, dim4, inner):
     formatB = F.get_special_format_str()
     err1, err2, err3 = [], [], []

@@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
 values = list(zip(dim1, dim4, inner))
 names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
 def test_row_scale_bench(dim1, dim4, inner):
     err1, err2, err3 = [], [], []
     relerr1, relerr2 = [], []

@@ -1183,6 +1185,7 @@ def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
 
 def test_overflow():
     formatB = F.get_special_format_str()
+    print(formatB)
     for i in range(2):
         a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
         b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)