Fixed makefile; fixed Ampere igemmlt_8 bug.
This commit is contained in:
parent
7d2ecd30c0
commit
8b1fd32e3e
11
CHANGELOG.md
11
CHANGELOG.md
|
@ -53,3 +53,14 @@ Bug fixes:
|
|||
|
||||
Docs:
|
||||
- Added instructions how to solve "\_\_fatbinwrap_" errors.
|
||||
|
||||
|
||||
### 0.30.0
|
||||
|
||||
#### 8-bit Inference Update
|
||||
|
||||
Features:
|
||||
- Added 8-bit matrix multiplication form cuBLAS, and cuBLASLt as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt)
|
||||
- Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation
|
||||
- Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations
|
||||
- CPU only build now available (Thank you, @mryab)
|
||||
|
|
2
Makefile
2
Makefile
|
@ -16,7 +16,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
|
|||
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
|
||||
|
||||
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
|
||||
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
|
||||
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
|
||||
|
||||
# NVIDIA NVCC compilation flags
|
||||
COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
|
||||
|
|
|
@ -228,7 +228,7 @@ extern "C"
|
|||
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
|
||||
|
|
38
cuda_install_111.sh
Normal file
38
cuda_install_111.sh
Normal file
|
@ -0,0 +1,38 @@
|
|||
FILE115=:cuda_11.5.1_495.29.05_linux.run
|
||||
FILE111=:cuda_11.1.1_455.32.00_linux.run
|
||||
URL115=:https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run
|
||||
URL111=:https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
|
||||
|
||||
|
||||
CUDA_VERSION=$1
|
||||
|
||||
if [[ -n "$CUDA_VERSION" ]]; then
|
||||
if [[ "$CUDA_VERSION" -eq "111" ]]; then
|
||||
FILE=cuda_11.1.1_455.32.00_linux.run
|
||||
URL=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
|
||||
FOLDER=cuda-11.1
|
||||
elif [[ "$CUDA_VERSION" -eq "115" ]]; then
|
||||
FILE=cuda_11.5.1_495.29.05_linux.run
|
||||
URL=https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run
|
||||
FOLDER=cuda-11.5
|
||||
else
|
||||
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
|
||||
fi
|
||||
else
|
||||
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
|
||||
fi
|
||||
|
||||
if [[ -n "$CUDA_VERSION" ]]; then
|
||||
echo $URL
|
||||
echo $FILE
|
||||
wget $URL
|
||||
bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=~/local/lib --toolkitpath=~/local/$FOLDER/ --toolkit --silent
|
||||
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/local/$FOLDER/lib64/" >> ~/.bashrc
|
||||
echo "export PATH=$PATH:~/local/$FOLDER/bin/" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
else
|
||||
echo ""
|
||||
fi
|
||||
|
||||
|
||||
|
90
quicktest.py
Normal file
90
quicktest.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
import torch
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
from itertools import product
|
||||
|
||||
def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
|
||||
k = 25
|
||||
for i in range(k):
|
||||
if dims == 2:
|
||||
A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
|
||||
elif dims == 3:
|
||||
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
|
||||
B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
|
||||
C1 = torch.matmul(A.float(), B.t().float())
|
||||
|
||||
A2, SA = F.transform(A, 'col32')
|
||||
B2, SB = F.transform(B, 'colx')
|
||||
if dims == 2:
|
||||
C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
|
||||
else:
|
||||
C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
|
||||
F.igemmlt(A2, B2, C2, SA, SB, SC)
|
||||
C3, S = F.transform(C2, 'row', state=SC)
|
||||
#torch.testing.assert_allclose(C1, C3.float())
|
||||
#print(C1)
|
||||
#print(C2)
|
||||
#print(C3)
|
||||
allclose = torch.allclose(C1, C3.float())
|
||||
if allclose:
|
||||
print(C1)
|
||||
print(C2)
|
||||
print(C3)
|
||||
|
||||
## transposed
|
||||
#A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
|
||||
#if dims == 2:
|
||||
# B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
|
||||
# C1 = torch.matmul(A.float(), B.float().t())
|
||||
#elif dims == 3:
|
||||
# B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
|
||||
# C1 = torch.matmul(B.float(), A.t().float())
|
||||
# C1 = C1.permute([2, 0, 1])
|
||||
|
||||
#A2, SA = F.transform(A, 'col32')
|
||||
#B2, SB = F.transform(B, 'colx')
|
||||
#if dims == 2:
|
||||
# C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
|
||||
#else:
|
||||
# C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
|
||||
# state = (C2.shape, 'row', A.shape[0])
|
||||
# C2, SC = F.transform(C2, 'col32', state=state)
|
||||
#F.igemmlt(A2, B2, C2, SA, SB, SC)
|
||||
#C3, S = F.transform(C2, 'row', state=SC, ld=[0])
|
||||
#torch.testing.assert_allclose(C1, C3.float())
|
||||
|
||||
## weight update
|
||||
#if dims == 3:
|
||||
# A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
|
||||
# B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
|
||||
# C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
|
||||
|
||||
# A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
|
||||
# B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
|
||||
# C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
|
||||
# C2, SC = F.transform(C2, 'col32')
|
||||
# F.igemmlt(B2, A2, C2, SB, SA, SC)
|
||||
# C3, S = F.transform(C2, 'row', state=SC)
|
||||
# torch.testing.assert_allclose(C1, C3.float())
|
||||
|
||||
|
||||
dims = (2, 3)
|
||||
ldb = [0]
|
||||
|
||||
n = 2
|
||||
dim1 = torch.randint(1,256, size=(n,)).tolist()
|
||||
dim2 = torch.randint(32,512, size=(n,)).tolist()
|
||||
dim3 = torch.randint(32,1024, size=(n,)).tolist()
|
||||
dim4 = torch.randint(32,1024, size=(n,)).tolist()
|
||||
values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
|
||||
|
||||
for ldb in range(32, 4096, 32):
|
||||
#for ldb in [None]:
|
||||
val = test_igemmlt(2, 2, 2, 2, 2, ldb)
|
||||
if val:
|
||||
print(val, ldb)
|
||||
else:
|
||||
print('nope', ldb)
|
||||
#for val in values:
|
||||
#test_igemmlt(*val)
|
|
@ -1183,6 +1183,7 @@ def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
|
|||
|
||||
def test_overflow():
|
||||
formatB = F.get_special_format_str()
|
||||
print(formatB)
|
||||
for i in range(2):
|
||||
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
|
||||
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
|
||||
|
|
Loading…
Reference in New Issue
Block a user