Merge branch 'debug' into cuda-bin-switch-and-cli

Tim Dettmers 2022-08-04 08:03:00 -07:00
commit 758c7175a2
5 changed files with 180 additions and 199 deletions

View File

@ -58,7 +58,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass
import torch
import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
@ -199,6 +199,17 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# default to pytorch behavior if inputs are empty
ctx.is_empty = False
if math.prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.grad_shape
)
return grad_A, grad_B, None, None, None, None, None
return grad_A, grad_B, None, None
matmul = MatMul8bitLt.apply
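
The empty-input path added above makes bnb.matmul fall back to plain PyTorch semantics when an operand has zero elements: forward returns an empty fp16 tensor of the right shape and backward returns all-zero gradients, without touching the CUDA kernels. A minimal sketch of that behavior (hypothetical shapes, assuming a CUDA-enabled bitsandbytes build):

import torch
import bitsandbytes as bnb

# Zero-sized batch: math.prod(A.shape) == 0, so forward skips quantization
# and returns an empty float16 tensor; backward returns zero gradients.
A = torch.randn(0, 4096, dtype=torch.float16, device='cuda', requires_grad=True)
B = torch.randn(4096, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
out = bnb.matmul(A, B)     # shape (0, 1024), dtype torch.float16
out.sum().backward()       # A.grad and B.grad are all-zero tensors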

View File

@ -4,9 +4,10 @@
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import random
from typing import Tuple
import math
import torch
from typing import Tuple
from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib
@ -193,6 +194,14 @@ def get_special_format_str():
return "col_turing"
def is_on_gpu(tensors):
on_gpu = True
for t in tensors:
if t is None: continue # NULL pointers are fine
on_gpu &= t.device.type == 'cuda'
return on_gpu
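
is_on_gpu folds a device check over a list of tensors, skipping None entries so optional arguments (NULL pointers on the C side) pass through. It only returns a bool, and the call sites added in this diff do not check the result, so a caller that wants a hard failure on CPU tensors has to assert on it explicitly. A hedged usage sketch with hypothetical tensors:

import torch
import bitsandbytes.functional as F

A = torch.randn(256, device='cuda' if torch.cuda.is_available() else 'cpu')
out = torch.zeros(256, dtype=torch.uint8, device=A.device)
# None entries are ignored, mirroring optional ctypes arguments.
if not F.is_on_gpu([A, out, None]):
    raise TypeError('expected all tensors on a CUDA device before the ctypes call')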
def get_ptr(A: Tensor) -> ct.c_void_p:
"""
Get the ctypes pointer from a PyTorch Tensor.
@ -336,7 +345,7 @@ def nvidia_transform(
def estimate_quantiles(
A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
"""
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@ -361,9 +370,9 @@ def estimate_quantiles(
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
"""
if out is None:
out = torch.zeros((256,), dtype=torch.float32, device=A.device)
'''
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
@ -428,7 +437,8 @@ def quantize_blockwise(
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != "cpu":
if A.device.type != 'cpu':
is_on_gpu([code, A, absmax, out, rand])
if rand is not None:
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@ -541,7 +551,8 @@ def dequantize_blockwise(
f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
)
if A.device.type != "cpu":
if A.device.type != 'cpu':
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(
get_ptr(quant_state[1]),
@ -610,7 +621,7 @@ def dequantize(
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
"""
'''
Quantizes input tensor to 8-bit.
Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-------
torch.Tensor:
Quantized 8-bit tensor.
"""
if out is None:
out = torch.zeros_like(A, dtype=torch.uint8)
'''
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
"""
'''
Dequantizes the 8-bit tensor to 32-bit.
Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-------
torch.Tensor:
32-bit output tensor.
"""
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
lib.cdequantize(
get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
)
'''
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
return out
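
quantize_no_absmax and dequantize_no_absmax map values through a fixed 256-entry codebook with no per-block absmax rescaling, so inputs are expected to already sit roughly in the codebook's range. A rough round-trip sketch (assuming a CUDA build; create_dynamic_map comes from bitsandbytes.functional and the shapes are hypothetical):

import torch
import bitsandbytes.functional as F

code = F.create_dynamic_map().cuda()      # 256-value codebook in [-1, 1]
A = torch.randn(4096, device='cuda')
A = A / A.abs().max()                     # pre-scale: no absmax is applied here
q = F.quantize_no_absmax(A, code)         # uint8 tensor, same shape as A
A_hat = F.dequantize_no_absmax(q, code)   # float32 reconstruction
print((A - A_hat).abs().max().item())     # small quantization error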
@ -983,6 +992,7 @@ def percentile_clipping(
The current optimization steps (number of past gradient norms).
"""
is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(
get_ptr(grad),
@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d(
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
lib.chistogram_scatter_add_2d(
get_ptr(histogram),
get_ptr(index1),
get_ptr(index2),
get_ptr(source),
maxdim1,
n,
)
is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(
A, B, out, transposed_A, transposed_B, expected_type=torch.int8
):
if not torch.cuda.is_initialized():
torch.cuda.init()
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
if not torch.cuda.is_initialized(): torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
raise TypeError(
f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
@ -1212,21 +1212,10 @@ def igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
# [km, nk -> mn]
lib.cigemm(
ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
)
# [km, nk -> mn]
is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
return out
@ -1306,24 +1295,10 @@ def batched_igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)
lib.cbatched_igemm(
ptr,
ct.c_bool(transposed_B),
ct.c_bool(transposed_A),
ct.c_int32(m),
ct.c_int32(n),
ct.c_int32(k),
get_ptr(B),
get_ptr(A),
get_ptr(out),
ct.c_int32(lda),
ct.c_int32(ldb),
ct.c_int32(ldc),
ct.c_long(strideA),
ct.c_long(strideB),
ct.c_long(strideC),
ct.c_uint32(num_batch),
)
is_on_gpu([B, A, out])
lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out
@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeB = SB[0]
dimsA = len(shapeA)
dimsB = len(shapeB)
assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
if dimsA == 2:
m = shapeA[0]
elif dimsA == 3:
m = shapeA[0] * shapeA[1]
if dimsB == 2:
rows = n = shapeB[0]
elif dimsB == 3:
rows = n = shapeB[0] * shapeB[1]
rows = n = shapeB[0]
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
# if the tensor is empty, return a transformed empty tensor with the right dimensions
if shapeA[0] == 0 and dimsA == 2:
return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
elif shapeA[1] == 0 and dimsA == 3:
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer(
@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error = 0
ptrRowScale = get_ptr(None)
if formatB == "col_turing":
is_on_gpu([A, B, out])
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
)
if has_error == 1:
raise Exception("cublasLt ran into an error!")
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
raise Exception('cublasLt ran into an error!')
torch.cuda.set_device(prev_device)
@ -1457,16 +1439,8 @@ def mm_dequant(
numRows = ct.c_int32(out_shape[0])
numCols = ct.c_int32(out_shape[1])
lib.cdequant_mm_int32_fp16(
ptrA,
ptrRowStats,
ptrColStats,
ptrOut,
ptrNewRowStats,
ptrNewColStats,
numRows,
numCols,
)
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
return out
@ -1507,15 +1481,8 @@ def get_colrow_absmax(
cols = ct.c_int32(cols)
prev_device = pre_call(A.device)
lib.cget_col_row_stats(
ptrA,
ptrRowStats,
ptrColStats,
ptrNnzrows,
ct.c_float(threshold),
rows,
cols,
)
is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
if threshold > 0.0:
@ -1642,6 +1609,7 @@ def double_quant(
ptrOutCol = get_ptr(out_col)
ptrOutRow = get_ptr(out_row)
is_on_gpu([A, col_stats, row_stats, out_col, out_row])
if threshold > 0.0:
nnz = nnz_row_ptr[-1].item()
if nnz > 0:
@ -1714,33 +1682,19 @@ def get_special_format_str():
)
assert major >= 7
if major == 7:
return "col_turing"
elif major == 8:
return "col_ampere"
else:
return "col_turing"
if major == 7: return 'col_turing'
elif major == 8: return 'col_ampere'
else: return 'col_turing'
def transform(
A,
to_order,
from_order="row",
out=None,
transpose=False,
state=None,
ld=None,
):
if state is None:
state = (A.shape, from_order)
else:
from_order = state[1]
if out is None:
out, new_state = get_transform_buffer(
state[0], A.dtype, A.device, to_order, state[1], transpose
)
else:
new_state = (state[0], to_order) # (shape, order)
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None: state = (A.shape, from_order)
else: from_order = state[1]
if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
else: new_state = (state[0], to_order) # (shape, order)
shape = state[0]
if len(shape) == 2:
@ -1752,7 +1706,8 @@ def transform(
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
if to_order == "col32":
is_on_gpu([A, out])
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
else:
@ -1773,9 +1728,9 @@ def transform(
elif from_order == "col_ampere":
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(
f"Transform function not implemented: From {from_order} to {to_order}"
)
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
post_call(prev_device)
return out, new_state
@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None):
cldb = ct.c_int32(ldb)
cldc = ct.c_int32(ldc)
lib.cspmm_coo(
ptr,
ptrRowidx,
ptrColidx,
ptrValues,
cnnz,
crowsA,
ccolsA,
ccolsB,
cldb,
ptrB,
cldc,
ptrC,
ct.c_bool(transposed_B),
)
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
return out
@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
# print(cooA.rowidx[:64])
# print(cooA.colidx[:64].sort()[0])
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
if B.dtype == torch.float16:
lib.cspmm_coo_very_sparse_naive_fp16(
ptrMaxCount,
@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx):
ptrIdx = get_ptr(idx)
ptrOut = get_ptr(out)
if formatA == "col_turing":
prev_device = pre_call(A.device)
if formatA == 'col_turing':
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
elif formatA == "col_ampere":
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
post_call(prev_device)
return out

View File

@ -19,53 +19,59 @@ using std::endl;
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512;
int blocks = n/threads;
blocks = n % threads == 0 ? blocks : blocks + 1;
kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
int num_blocks = n/threads;
num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
int blocks = n/4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
{
int blocks = n/1024;
blocks = n % 1024 == 0 ? blocks : blocks + 1;
kQuantize<<<blocks, 1024>>>(code, A, out, n);
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n)
{
int blocks = n/1024;
blocks = n % 1024 == 0 ? blocks : blocks + 1;
kDequantize<<<blocks, 1024>>>(code, A, out, n);
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
{
int blocks = n/4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
int blocks = n/blocksize;
blocks = n % blocksize == 0 ? blocks : blocks + 1;
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(blocksize == 4096)
kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048)
kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
int blocks = n/4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
switch(OPTIMIZER)
{
case ADAM:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
float weight_decay,
const float gnorm_scale, int n)
{
int blocks = n/4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
int blocks = 0;
int num_blocks = 0;
switch(OPTIMIZER)
{
case ADAM:
blocks = n/BLOCKSIZE_2STATE;
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
blocks = n/BLOCKSIZE_1STATE;
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
int blocks = n/2048;
blocks = n % 2048 == 0 ? blocks : blocks + 1;
int num_blocks = n/2048;
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
assert(threads <= tilesize);
//cout << num_blocks << " blocks" << endl;
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
int row_tiles = (tiledRows/STATS_ROWS);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
//cout << "num blocks " << num_blocks << endl;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
//cout << A << " " << out_col_normed << endl;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
int row_tiles = (tiledRows/tile_rows);
int col_tiles = (tiledCols/tile_cols);
row_tiles = row_tiles > 0 ? row_tiles : 1;
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)
@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
}
}
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
//cout << "num blocks " << num_blocks << endl;
//cout << A << " " << out_col_normed << endl;
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
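
The launchers in this file now derive their grid size with the same ceil-division and guard it with the new 65535-block assert. A small arithmetic sketch (plain Python, not part of the library) of what that bound means for the 4096-elements-per-block launchers:

def num_blocks(n, elems_per_block):
    # ceil-division used by the launchers above
    blocks = n // elems_per_block
    return blocks if n % elems_per_block == 0 else blocks + 1

MAX_BLOCKS = 65535                        # bound asserted in ops.cu
print(num_blocks(2_000_000, 4096))        # 489 blocks for a 2M-element tensor
print(MAX_BLOCKS * 4096)                  # 268,431,360 elements before the assert fires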

View File

@ -40,7 +40,8 @@ names = [
ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dim2 = dim2 - (dim2 % 16)
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
# dim1 = (17,)
# dim2 = (7,)
# dim3 = (37,)
# dim4 = (23,)
dim2.append(0)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
@ -385,9 +383,14 @@ def test_matmullt(
)
if req_grad[1]:
n = gradB1.numel()
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02