Some initial code. Needs to be tested.

Tim Dettmers 2022-08-23 13:59:34 -07:00
parent 9d60b3c527
commit 7e0fb655e1
5 changed files with 42 additions and 37 deletions

View File

@@ -17,6 +17,7 @@ evaluation:
 """
 import ctypes
+import torch
 from pathlib import Path
 from ..utils import execute_and_return

@@ -28,7 +29,7 @@ def check_cuda_result(cuda, result_val):
     if result_val != 0:
         error_str = ctypes.c_char_p()
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
-        raise Exception(f"CUDA exception! Error code: {error_str.value.decode()}")
+        print(f"CUDA exception! Error code: {error_str.value.decode()}")

 def get_cuda_version(cuda, cudart_path):
     # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION

@@ -57,7 +58,7 @@ def get_cuda_lib_handle():
         cuda = ctypes.CDLL("libcuda.so")
     except OSError:
         # TODO: shouldn't we error or at least warn here?
-        raise Exception('CUDA SETUP: ERROR! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
+        print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
         return None

     check_cuda_result(cuda, cuda.cuInit(0))

@@ -119,6 +120,10 @@ def evaluate_cuda_setup():
     print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
     print('='*80)
     binary_name = "libbitsandbytes_cpu.so"
+    #if not torch.cuda.is_available():
+        #print('No GPU detected. Loading CPU library...')
+        #return binary_name
+
     cudart_path = determine_cuda_runtime_lib_path()
     if cudart_path is None:
         print(

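The changes in this file all point the same direction: CUDA-setup failures stop raising and start printing, so machines without a CUDA driver can fall through to the CPU binary, and the commented-out block sketches an early exit before any CUDA probing. A minimal sketch of that fallback shape; `pick_binary_name` and the `libbitsandbytes_cuda.so` name are illustrative assumptions, not part of the diff:

import torch

def pick_binary_name() -> str:
    # CPU-only machines short-circuit before any libcuda.so probing.
    if not torch.cuda.is_available():
        print('No GPU detected. Loading CPU library...')
        return "libbitsandbytes_cpu.so"
    # With a GPU visible, the real setup goes on to locate the CUDA runtime
    # (determine_cuda_runtime_lib_path) and pick a CUDA-specific binary;
    # the name below is a placeholder.
    return "libbitsandbytes_cuda.so"

Returning early like this keeps the libcuda.so load and runtime lookup off the CPU-only path entirely, which is why the hard exceptions above could become prints.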
View File

@@ -1686,11 +1686,10 @@ def double_quant(

 def get_special_format_str():
+    if not torch.cuda.is_available(): return 'col_turing'
     major, minor = torch.cuda.get_device_capability()
     if major < 7:
-        print(
-            f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
-        )
+        print(f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!")
     assert major >= 7
     if major == 7: return 'col_turing'

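The added guard lets get_special_format_str run on CPU-only machines by returning a harmless default before touching torch.cuda. A sketch of the full selection logic as it plausibly continues past the shown hunk; the 'col_ampere' branch is an assumption about the unshown tail of the function, not part of this diff:

import torch

def special_format_for_capability() -> str:
    if not torch.cuda.is_available():
        return 'col_turing'  # CPU-only: harmless default, never used for kernels
    major, _minor = torch.cuda.get_device_capability()
    if major < 7:
        print(f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!")
    assert major >= 7
    if major == 7:
        return 'col_turing'  # sm_70/sm_75: Turing tile layout
    return 'col_ampere'      # sm_80+: assumed Ampere layout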
View File

@@ -5,7 +5,6 @@
 from bitsandbytes.cextension import COMPILED_WITH_CUDA

-if COMPILED_WITH_CUDA:
-    from .adam import Adam, Adam8bit, Adam32bit
-    from .adamw import AdamW, AdamW8bit, AdamW32bit
-    from .sgd import SGD, SGD8bit, SGD32bit
+from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
+from .sgd import SGD, SGD8bit, SGD32bit

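With the COMPILED_WITH_CUDA guard removed, the optimizer classes import unconditionally. A quick usage sketch, assuming the shared library resolves on a CPU-only install; constructing an optimizer is cheap, but actually stepping an 8-bit optimizer still needs the CUDA kernels:

import torch
import bitsandbytes as bnb

# After this change the optimizers are importable even without a GPU.
model = torch.nn.Linear(64, 64)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
print(type(opt).__name__)  # -> Adam8bit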
View File

@@ -40,6 +40,7 @@ names = [
     ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
     if dim2 > 0:
         dim2 = dim2 - (dim2 % 16)
         dim3 = dim3 - (dim3 % 16)
@@ -306,6 +307,7 @@ def test_matmullt(
     has_fp16_weights,
     has_bias
 ):
+    if not torch.cuda.is_available(): pytest.skip('No GPU found.')
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")

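Both tests now self-skip when no GPU is present. An equivalent, marker-based form of the same guard; the test body below is illustrative, not from the diff:

import pytest
import torch

# Marker-based equivalent of the inline pytest.skip guard; can also be applied
# once per module via `pytestmark = requires_cuda`.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="No GPU found."
)

@requires_cuda
def test_matmul_smoke():
    a = torch.randn(8, 16, device="cuda")
    b = torch.randn(16, 4, device="cuda")
    assert torch.matmul(a, b).shape == (8, 4)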
View File

@@ -1813,16 +1813,16 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
 batch_size = 1
-seqdim = 2048
+seqdim = 1
 values = []
-values.append((batch_size, seqdim, 768, 4 * 768))
+#values.append((batch_size, seqdim, 768, 4 * 768))
 # values.append((batch_size, seqdim, 1024, 4*1024))
 # values.append((batch_size, seqdim, 1536, 4*1536))
 # values.append((batch_size, seqdim, 2048, 4*2048))
 # values.append((batch_size, seqdim, 2560, 4*2560))
 # values.append((batch_size, seqdim, 4096, 4*4096))
 # values.append((batch_size, seqdim, 5140, 4*5140))
-# values.append((batch_size, seqdim, 12288, 4*12288))
+values.append((batch_size, seqdim, 12288, 4*12288))
 names = [
     "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
 ]
@@ -1830,6 +1830,7 @@ names = [
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
+    iters = 128
     formatB = F.get_special_format_str()

     A = torch.randn(batch, seq, model, device="cuda").half()
@@ -1848,28 +1849,33 @@ def test_bench_matmul(batch, seq, model, hidden):
     linearMixedBit.eval()

     # warmup
-    for i in range(100):
+    for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print("")

     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print(
-        f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
     )

     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         bnb.matmul(A, B)
     torch.cuda.synchronize()
-    print(
-        f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul(A, B, threshold=6.0)
+    torch.cuda.synchronize()
+    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
     C32A, SA = F.transform(CA, "col32")
@@ -1877,18 +1883,16 @@ def test_bench_matmul(batch, seq, model, hidden):
     CxB, SB = F.transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
     torch.cuda.synchronize()
-    print(
-        f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1)
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         A2 = A.view(-1, A.shape[-1]).contiguous()
         CA, statsA = F.vectorwise_quant(A2, dim=1)
         C32A, SA = F.nvidia_transform(CA, "col32")
@@ -1896,15 +1900,13 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
     torch.cuda.synchronize()
-    print(
-        f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         A2 = A.view(-1, A.shape[-1]).contiguous()
         CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
         C32A, SA = F.nvidia_transform(CA, "col32")
@@ -1912,14 +1914,12 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         out = Cout * statsB * statsA * (1.0 / (127 * 127))
     torch.cuda.synchronize()
-    print(
-        f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linear8bit(A)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
     print(

@@ -1929,7 +1929,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linearMixedBit(A)
     torch.cuda.synchronize()
     t0 = time.time()
-    for i in range(100):
+    for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
     print(
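Every timed section in test_bench_matmul repeats the same pattern: warm up, torch.cuda.synchronize(), stamp time.time(), run the loop iters times, synchronize again, print. A hypothetical helper that factors that pattern out; bench is not part of the diff:

import time
import torch

def bench(label: str, fn, iters: int = 128) -> None:
    # Warmup so one-time costs (allocation, autotuning) stay out of the timing.
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # CUDA launches are async; drain queued kernels first
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for the timed kernels before reading the clock
    print(f"{label}: {time.time() - t0:.4f}s")

# Usage, mirroring the first timed section:
# bench("pytorch fp16", lambda: torch.matmul(A, B.t()), iters=128)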