Merge branch 'debug' into cuda-bin-switch-and-cli
This commit is contained in:
commit
758c7175a2
2
Makefile
2
Makefile
|
@ -58,7 +58,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
|
|||
|
||||
|
||||
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
import math
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
|
@ -199,6 +199,17 @@ class MatmulLtState:
|
|||
class MatMul8bitLt(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, state=MatmulLtState()):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if math.prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
if A.shape[-1] == B.shape[0]:
|
||||
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
|
||||
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
|
@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
|
||||
req_gradA, req_gradB = ctx.req_grads
|
||||
CAt, subA = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
|
@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
ctx.grad_shape
|
||||
)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
return grad_A, grad_B, None, None
|
||||
|
||||
|
||||
matmul = MatMul8bitLt.apply
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
import ctypes as ct
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from torch import Tensor
|
||||
|
||||
from .cextension import COMPILED_WITH_CUDA, lib
|
||||
|
@ -193,6 +194,14 @@ def get_special_format_str():
|
|||
return "col_turing"
|
||||
|
||||
|
||||
|
||||
def is_on_gpu(tensors):
|
||||
on_gpu = True
|
||||
for t in tensors:
|
||||
if t is None: continue # NULL pointers are fine
|
||||
on_gpu &= t.device.type == 'cuda'
|
||||
return on_gpu
|
||||
|
||||
def get_ptr(A: Tensor) -> ct.c_void_p:
|
||||
"""
|
||||
Get the ctypes pointer from a PyTorch Tensor.
|
||||
|
@ -336,7 +345,7 @@ def nvidia_transform(
|
|||
def estimate_quantiles(
|
||||
A: Tensor, out: Tensor = None, offset: float = 1 / 512
|
||||
) -> Tensor:
|
||||
"""
|
||||
'''
|
||||
Estimates 256 equidistant quantiles on the input tensor eCDF.
|
||||
|
||||
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
|
||||
|
@ -361,9 +370,9 @@ def estimate_quantiles(
|
|||
-------
|
||||
torch.Tensor:
|
||||
The 256 quantiles in float32 datatype.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.zeros((256,), dtype=torch.float32, device=A.device)
|
||||
'''
|
||||
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
|
||||
is_on_gpu([A, out])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cestimate_quantiles_fp32(
|
||||
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
|
||||
|
@ -428,7 +437,8 @@ def quantize_blockwise(
|
|||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
|
||||
if A.device.type != "cpu":
|
||||
if A.device.type != 'cpu':
|
||||
is_on_gpu([code, A, absmax, out, rand])
|
||||
if rand is not None:
|
||||
assert rand.numel() >= 1024
|
||||
rand_offset = random.randint(0, 1023)
|
||||
|
@ -541,7 +551,8 @@ def dequantize_blockwise(
|
|||
f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
|
||||
)
|
||||
|
||||
if A.device.type != "cpu":
|
||||
if A.device.type != 'cpu':
|
||||
is_on_gpu([A, out])
|
||||
if out.dtype == torch.float32:
|
||||
lib.cdequantize_blockwise_fp32(
|
||||
get_ptr(quant_state[1]),
|
||||
|
@ -610,7 +621,7 @@ def dequantize(
|
|||
|
||||
|
||||
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
||||
"""
|
||||
'''
|
||||
Quantizes input tensor to 8-bit.
|
||||
|
||||
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
|
||||
|
@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
|||
-------
|
||||
torch.Tensor:
|
||||
Quantized 8-bit tensor.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
'''
|
||||
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
is_on_gpu([A, out])
|
||||
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
|
||||
return out
|
||||
|
||||
|
||||
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
||||
"""
|
||||
'''
|
||||
Dequantizes the 8-bit tensor to 32-bit.
|
||||
|
||||
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
|
||||
|
@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
|||
-------
|
||||
torch.Tensor:
|
||||
32-bit output tensor.
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
lib.cdequantize(
|
||||
get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
|
||||
)
|
||||
'''
|
||||
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
|
||||
is_on_gpu([code, A, out])
|
||||
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
|
||||
return out
|
||||
|
||||
|
||||
|
@ -983,6 +992,7 @@ def percentile_clipping(
|
|||
The current optimiation steps (number of past gradient norms).
|
||||
|
||||
"""
|
||||
is_on_gpu([grad, gnorm_vec])
|
||||
if grad.dtype == torch.float32:
|
||||
lib.cpercentile_clipping_g32(
|
||||
get_ptr(grad),
|
||||
|
@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d(
|
|||
|
||||
maxdim1 = ct.c_int32(histogram.shape[0])
|
||||
n = ct.c_int32(index1.numel())
|
||||
lib.chistogram_scatter_add_2d(
|
||||
get_ptr(histogram),
|
||||
get_ptr(index1),
|
||||
get_ptr(index2),
|
||||
get_ptr(source),
|
||||
maxdim1,
|
||||
n,
|
||||
)
|
||||
is_on_gpu([histogram, index1, index2d, source])
|
||||
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
|
||||
|
||||
|
||||
def check_matmul(
|
||||
A, B, out, transposed_A, transposed_B, expected_type=torch.int8
|
||||
):
|
||||
if not torch.cuda.is_initialized():
|
||||
torch.cuda.init()
|
||||
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
|
||||
if not torch.cuda.is_initialized(): torch.cuda.init()
|
||||
if A.dtype != expected_type or B.dtype != expected_type:
|
||||
raise TypeError(
|
||||
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
|
||||
|
@ -1212,21 +1212,10 @@ def igemm(
|
|||
ptr = CUBLAS_Context.get_instance().get_context(A.device)
|
||||
|
||||
# B^T @ A^T = C^T
|
||||
# [km, nk -> mn]
|
||||
lib.cigemm(
|
||||
ptr,
|
||||
ct.c_bool(transposed_B),
|
||||
ct.c_bool(transposed_A),
|
||||
ct.c_int32(m),
|
||||
ct.c_int32(n),
|
||||
ct.c_int32(k),
|
||||
get_ptr(B),
|
||||
get_ptr(A),
|
||||
get_ptr(out),
|
||||
ct.c_int32(lda),
|
||||
ct.c_int32(ldb),
|
||||
ct.c_int32(ldc),
|
||||
)
|
||||
# [km, nk -> mn]
|
||||
is_on_gpu([B, A, out])
|
||||
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
|
||||
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
|
||||
return out
|
||||
|
||||
|
||||
|
@ -1306,24 +1295,10 @@ def batched_igemm(
|
|||
|
||||
ptr = CUBLAS_Context.get_instance().get_context(A.device)
|
||||
|
||||
lib.cbatched_igemm(
|
||||
ptr,
|
||||
ct.c_bool(transposed_B),
|
||||
ct.c_bool(transposed_A),
|
||||
ct.c_int32(m),
|
||||
ct.c_int32(n),
|
||||
ct.c_int32(k),
|
||||
get_ptr(B),
|
||||
get_ptr(A),
|
||||
get_ptr(out),
|
||||
ct.c_int32(lda),
|
||||
ct.c_int32(ldb),
|
||||
ct.c_int32(ldc),
|
||||
ct.c_long(strideA),
|
||||
ct.c_long(strideB),
|
||||
ct.c_long(strideC),
|
||||
ct.c_uint32(num_batch),
|
||||
)
|
||||
is_on_gpu([B, A, out])
|
||||
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
|
||||
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
|
||||
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
|
||||
return out
|
||||
|
||||
|
||||
|
@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
|||
shapeB = SB[0]
|
||||
dimsA = len(shapeA)
|
||||
dimsB = len(shapeB)
|
||||
assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
|
||||
if dimsA == 2:
|
||||
m = shapeA[0]
|
||||
elif dimsA == 3:
|
||||
m = shapeA[0] * shapeA[1]
|
||||
|
||||
if dimsB == 2:
|
||||
rows = n = shapeB[0]
|
||||
elif dimsB == 3:
|
||||
rows = n = shapeB[0] * shapeB[1]
|
||||
rows = n = shapeB[0]
|
||||
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
|
||||
|
||||
# if the tensor is empty, return a transformed empty tensor with the right dimensions
|
||||
if shapeA[0] == 0 and dimsA == 2:
|
||||
return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
|
||||
elif shapeA[1] == 0 and dimsA == 3:
|
||||
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
|
||||
|
||||
if dimsA == 2 and out is None:
|
||||
out, Sout = get_transform_buffer(
|
||||
|
@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
|||
|
||||
has_error = 0
|
||||
ptrRowScale = get_ptr(None)
|
||||
if formatB == "col_turing":
|
||||
is_on_gpu([A, B, out])
|
||||
if formatB == 'col_turing':
|
||||
if dtype == torch.int32:
|
||||
has_error = lib.cigemmlt_turing_32(
|
||||
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
|
||||
|
@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
|||
)
|
||||
|
||||
if has_error == 1:
|
||||
raise Exception("cublasLt ran into an error!")
|
||||
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
|
||||
raise Exception('cublasLt ran into an error!')
|
||||
|
||||
torch.cuda.set_device(prev_device)
|
||||
|
||||
|
@ -1457,16 +1439,8 @@ def mm_dequant(
|
|||
numRows = ct.c_int32(out_shape[0])
|
||||
numCols = ct.c_int32(out_shape[1])
|
||||
|
||||
lib.cdequant_mm_int32_fp16(
|
||||
ptrA,
|
||||
ptrRowStats,
|
||||
ptrColStats,
|
||||
ptrOut,
|
||||
ptrNewRowStats,
|
||||
ptrNewColStats,
|
||||
numRows,
|
||||
numCols,
|
||||
)
|
||||
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
|
||||
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
|
||||
|
||||
return out
|
||||
|
||||
|
@ -1507,15 +1481,8 @@ def get_colrow_absmax(
|
|||
cols = ct.c_int32(cols)
|
||||
|
||||
prev_device = pre_call(A.device)
|
||||
lib.cget_col_row_stats(
|
||||
ptrA,
|
||||
ptrRowStats,
|
||||
ptrColStats,
|
||||
ptrNnzrows,
|
||||
ct.c_float(threshold),
|
||||
rows,
|
||||
cols,
|
||||
)
|
||||
is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
|
||||
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
|
||||
post_call(prev_device)
|
||||
|
||||
if threshold > 0.0:
|
||||
|
@ -1642,6 +1609,7 @@ def double_quant(
|
|||
ptrOutCol = get_ptr(out_col)
|
||||
ptrOutRow = get_ptr(out_row)
|
||||
|
||||
is_on_gpu([A, col_stats, row_stats, out_col, out_row])
|
||||
if threshold > 0.0:
|
||||
nnz = nnz_row_ptr[-1].item()
|
||||
if nnz > 0:
|
||||
|
@ -1714,33 +1682,19 @@ def get_special_format_str():
|
|||
)
|
||||
assert major >= 7
|
||||
|
||||
if major == 7:
|
||||
return "col_turing"
|
||||
elif major == 8:
|
||||
return "col_ampere"
|
||||
else:
|
||||
return "col_turing"
|
||||
if major == 7: return 'col_turing'
|
||||
elif major == 8: return 'col_ampere'
|
||||
else: return 'col_turing'
|
||||
|
||||
|
||||
def transform(
|
||||
A,
|
||||
to_order,
|
||||
from_order="row",
|
||||
out=None,
|
||||
transpose=False,
|
||||
state=None,
|
||||
ld=None,
|
||||
):
|
||||
if state is None:
|
||||
state = (A.shape, from_order)
|
||||
else:
|
||||
from_order = state[1]
|
||||
if out is None:
|
||||
out, new_state = get_transform_buffer(
|
||||
state[0], A.dtype, A.device, to_order, state[1], transpose
|
||||
)
|
||||
else:
|
||||
new_state = (state[0], to_order) # (shape, order)
|
||||
|
||||
|
||||
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
|
||||
prev_device = pre_call(A.device)
|
||||
if state is None: state = (A.shape, from_order)
|
||||
else: from_order = state[1]
|
||||
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
|
||||
else: new_state = (state[0], to_order) # (shape, order)
|
||||
|
||||
shape = state[0]
|
||||
if len(shape) == 2:
|
||||
|
@ -1752,7 +1706,8 @@ def transform(
|
|||
|
||||
ptrA = get_ptr(A)
|
||||
ptrOut = get_ptr(out)
|
||||
if to_order == "col32":
|
||||
is_on_gpu([A, out])
|
||||
if to_order == 'col32':
|
||||
if transpose:
|
||||
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
|
||||
else:
|
||||
|
@ -1773,9 +1728,9 @@ def transform(
|
|||
elif from_order == "col_ampere":
|
||||
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Transform function not implemented: From {from_order} to {to_order}"
|
||||
)
|
||||
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
|
||||
|
||||
post_call(prev_device)
|
||||
|
||||
return out, new_state
|
||||
|
||||
|
@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None):
|
|||
cldb = ct.c_int32(ldb)
|
||||
cldc = ct.c_int32(ldc)
|
||||
|
||||
lib.cspmm_coo(
|
||||
ptr,
|
||||
ptrRowidx,
|
||||
ptrColidx,
|
||||
ptrValues,
|
||||
cnnz,
|
||||
crowsA,
|
||||
ccolsA,
|
||||
ccolsB,
|
||||
cldb,
|
||||
ptrB,
|
||||
cldc,
|
||||
ptrC,
|
||||
ct.c_bool(transposed_B),
|
||||
)
|
||||
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
|
||||
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
|
||||
|
||||
return out
|
||||
|
||||
|
@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
|
|||
# print(cooA.rowidx[:64])
|
||||
# print(cooA.colidx[:64].sort()[0])
|
||||
|
||||
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
|
||||
if B.dtype == torch.float16:
|
||||
lib.cspmm_coo_very_sparse_naive_fp16(
|
||||
ptrMaxCount,
|
||||
|
@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx):
|
|||
ptrIdx = get_ptr(idx)
|
||||
ptrOut = get_ptr(out)
|
||||
|
||||
if formatA == "col_turing":
|
||||
prev_device = pre_call(A.device)
|
||||
if formatA == 'col_turing':
|
||||
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
||||
elif formatA == "col_ampere":
|
||||
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
|
||||
post_call(prev_device)
|
||||
|
||||
return out
|
||||
|
|
126
csrc/ops.cu
126
csrc/ops.cu
|
@ -19,53 +19,59 @@ using std::endl;
|
|||
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
|
||||
{
|
||||
int threads = 512;
|
||||
int blocks = n/threads;
|
||||
blocks = n % threads == 0 ? blocks : blocks + 1;
|
||||
kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
|
||||
int num_blocks = n/threads;
|
||||
num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
|
||||
kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
|
||||
kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n)
|
||||
{
|
||||
int blocks = n/1024;
|
||||
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||
kQuantize<<<blocks, 1024>>>(code, A, out, n);
|
||||
int num_blocks = n/1024;
|
||||
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n)
|
||||
{
|
||||
int blocks = n/1024;
|
||||
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||
kDequantize<<<blocks, 1024>>>(code, A, out, n);
|
||||
int num_blocks = n/1024;
|
||||
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
{
|
||||
int blocks = n/blocksize;
|
||||
blocks = n % blocksize == 0 ? blocks : blocks + 1;
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
if(blocksize == 4096)
|
||||
kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 2048)
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
|
@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
|
@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
float weight_decay,
|
||||
const float gnorm_scale, int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
|
||||
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
|
||||
|
||||
|
@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
case ADAM:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
|
||||
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
|
@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
|
@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
|
||||
{
|
||||
|
||||
int blocks = 0;
|
||||
int num_blocks = 0;
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
blocks = n/BLOCKSIZE_2STATE;
|
||||
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
|
||||
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
|
||||
num_blocks = n/BLOCKSIZE_2STATE;
|
||||
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
blocks = n/BLOCKSIZE_1STATE;
|
||||
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
num_blocks = n/BLOCKSIZE_1STATE;
|
||||
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
|
@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
|
||||
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
|
||||
{
|
||||
int blocks = n/2048;
|
||||
blocks = n % 2048 == 0 ? blocks : blocks + 1;
|
||||
int num_blocks = n/2048;
|
||||
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
|
||||
kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
|
||||
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
|
|||
int num_blocks = numRows/subtile_rows;
|
||||
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
|
||||
num_blocks = num_blocks*(tileCols/32);
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
assert(threads <= tilesize);
|
||||
|
||||
//cout << num_blocks << " blocks" << endl;
|
||||
|
||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
|
|||
int tile_cols = STATS_THREADS*STATS_ITEMS;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
|
||||
int row_tiles = (tiledRows/STATS_ROWS);
|
||||
int col_tiles = (tiledCols/tile_cols);
|
||||
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
|
||||
if(nnz_threshold == 0.0)
|
||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
|
@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
|
|||
int tile_rows = 16;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
||||
int row_tiles = (tiledRows/tile_rows);
|
||||
int col_tiles = (tiledCols/tile_cols);
|
||||
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
|
||||
//cout << "num blocks " << num_blocks << endl;
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
|
||||
//cout << A << " " << out_col_normed << endl;
|
||||
if(threshold > 0.0f)
|
||||
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
||||
else
|
||||
|
@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
|||
int tile_rows = 32;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
||||
int row_tiles = (tiledRows/tile_rows);
|
||||
int col_tiles = (tiledCols/tile_cols);
|
||||
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
int outCols = fill_up_to_nearest_multiple(cols, 32);
|
||||
int outRows = fill_up_to_nearest_multiple(rows, 32);
|
||||
if(FORMAT == COL_TURING)
|
||||
|
@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
|||
}
|
||||
}
|
||||
|
||||
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
|
||||
//cout << "num blocks " << num_blocks << endl;
|
||||
|
||||
//cout << A << " " << out_col_normed << endl;
|
||||
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
|
|
@ -40,7 +40,8 @@ names = [
|
|||
ids=names,
|
||||
)
|
||||
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||
dim2 = dim2 - (dim2 % 16)
|
||||
if dim2 > 0:
|
||||
dim2 = dim2 - (dim2 % 16)
|
||||
dim3 = dim3 - (dim3 % 16)
|
||||
dim4 = dim4 - (dim4 % 16)
|
||||
for i in range(k):
|
||||
|
@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
|
|||
dim3 = torch.randint(32, 96, size=(n,)).tolist()
|
||||
dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
||||
|
||||
# dim1 = (17,)
|
||||
# dim2 = (7,)
|
||||
# dim3 = (37,)
|
||||
# dim4 = (23,)
|
||||
dim2.append(0)
|
||||
|
||||
decomp = [0.0, 6.0]
|
||||
funcs = [(torch.matmul, bnb.matmul)]
|
||||
|
@ -385,9 +383,14 @@ def test_matmullt(
|
|||
)
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
if dim2 > 0:
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
else:
|
||||
assert torch.abs(gradB1).sum() == 0.0
|
||||
assert torch.abs(gradB2).sum() == 0.0
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
|
||||
assert (idx == 0).sum().item() < n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() < n * 0.02
|
||||
|
|
Loading…
Reference in New Issue
Block a user