Merge branch 'debug' into cuda-bin-switch-and-cli

Tim Dettmers 2022-08-04 08:03:00 -07:00
commit 758c7175a2
5 changed files with 180 additions and 199 deletions

View File

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import torch
+import math
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
@@ -199,6 +199,17 @@ class MatmulLtState:
 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if math.prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
         req_gradA, req_gradB = ctx.req_grads
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
@@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.grad_shape
         )
-        return grad_A, grad_B, None, None, None, None, None
+        return grad_A, grad_B, None, None
 matmul = MatMul8bitLt.apply
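
Aside: the new empty-input path above can be exercised without the CUDA kernels at all. A minimal sketch, assuming only that one dimension of A is zero (the tensor shapes here are illustrative, not taken from the commit):

import math
import torch

# An input with a zero-sized dimension short-circuits the 8-bit matmul path
# and returns an empty fp16 tensor with the usual matmul output shape.
A = torch.empty(4, 0, 32)   # math.prod(A.shape) == 0
B = torch.empty(32, 64)

if math.prod(A.shape) == 0 and A.shape[-1] == B.shape[0]:
    out = torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16)
    print(out.shape)        # torch.Size([4, 0, 64]); backward then returns zeros_like(A), zeros_like(B)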
View File

@@ -4,9 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 import ctypes as ct
 import random
-from typing import Tuple
+import math
 import torch
+from typing import Tuple
 from torch import Tensor
 from .cextension import COMPILED_WITH_CUDA, lib
@@ -193,6 +194,14 @@ def get_special_format_str():
     return "col_turing"
+def is_on_gpu(tensors):
+    on_gpu = True
+    for t in tensors:
+        if t is None: continue # NULL pointers are fine
+        on_gpu &= t.device.type == 'cuda'
+    return on_gpu
 def get_ptr(A: Tensor) -> ct.c_void_p:
     """
     Get the ctypes pointer from a PyTorch Tensor.
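
Aside: a quick usage sketch of the is_on_gpu guard introduced above (restated standalone for illustration; in the library it sits next to the ctypes wrappers and is called right before each lib.* call):

import torch

def is_on_gpu(tensors):
    # Same logic as the helper above: None entries (NULL pointers) are skipped,
    # every real tensor must live on a CUDA device.
    on_gpu = True
    for t in tensors:
        if t is None:
            continue
        on_gpu &= t.device.type == 'cuda'
    return on_gpu

a = torch.zeros(4)              # CPU tensor
print(is_on_gpu([a, None]))     # False: `a` lives on the CPU, None is ignored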
@@ -336,7 +345,7 @@ def nvidia_transform(
 def estimate_quantiles(
     A: Tensor, out: Tensor = None, offset: float = 1 / 512
 ) -> Tensor:
-    """
+    '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
     Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
@@ -361,9 +370,9 @@ def estimate_quantiles(
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
-    """
-    if out is None:
-        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+    '''
+    if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+    is_on_gpu([A, out])
     if A.dtype == torch.float32:
         lib.cestimate_quantiles_fp32(
             get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
@@ -428,7 +437,8 @@ def quantize_blockwise(
     if out is None:
         out = torch.zeros_like(A, dtype=torch.uint8)
-    if A.device.type != "cpu":
+    if A.device.type != 'cpu':
+        is_on_gpu([code, A, absmax, out, rand])
         if rand is not None:
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
@@ -541,7 +551,8 @@ def dequantize_blockwise(
             f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
         )
-    if A.device.type != "cpu":
+    if A.device.type != 'cpu':
+        is_on_gpu([A, out])
         if out.dtype == torch.float32:
             lib.cdequantize_blockwise_fp32(
                 get_ptr(quant_state[1]),
@@ -610,7 +621,7 @@ def dequantize(
 def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-    """
+    '''
     Quantizes input tensor to 8-bit.
     Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
@@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     -------
     torch.Tensor:
         Quantized 8-bit tensor.
-    """
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.uint8)
+    '''
+    if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
+    is_on_gpu([A, out])
     lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
     return out
 def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-    """
+    '''
     Dequantizes the 8-bit tensor to 32-bit.
     Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
@@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     -------
     torch.Tensor:
         32-bit output tensor.
-    """
-    if out is None:
-        out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(
-        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
-    )
+    '''
+    if out is None: out = torch.zeros_like(A, dtype=torch.float32)
+    is_on_gpu([code, A, out])
+    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
     return out
@@ -983,6 +992,7 @@ def percentile_clipping(
         The current optimiation steps (number of past gradient norms).
     """
+    is_on_gpu([grad, gnorm_vec])
     if grad.dtype == torch.float32:
         lib.cpercentile_clipping_g32(
             get_ptr(grad),
@@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d(
     maxdim1 = ct.c_int32(histogram.shape[0])
     n = ct.c_int32(index1.numel())
-    lib.chistogram_scatter_add_2d(
-        get_ptr(histogram),
-        get_ptr(index1),
-        get_ptr(index2),
-        get_ptr(source),
-        maxdim1,
-        n,
-    )
+    is_on_gpu([histogram, index1, index2d, source])
+    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
-def check_matmul(
-    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
-):
-    if not torch.cuda.is_initialized():
-        torch.cuda.init()
+def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+    if not torch.cuda.is_initialized(): torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
         raise TypeError(
             f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
@@ -1213,20 +1213,9 @@ def igemm(
     # B^T @ A^T = C^T
     # [km, nk -> mn]
-    lib.cigemm(
-        ptr,
-        ct.c_bool(transposed_B),
-        ct.c_bool(transposed_A),
-        ct.c_int32(m),
-        ct.c_int32(n),
-        ct.c_int32(k),
-        get_ptr(B),
-        get_ptr(A),
-        get_ptr(out),
-        ct.c_int32(lda),
-        ct.c_int32(ldb),
-        ct.c_int32(ldc),
-    )
+    is_on_gpu([B, A, out])
+    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
     return out
@@ -1306,24 +1295,10 @@ def batched_igemm(
     ptr = CUBLAS_Context.get_instance().get_context(A.device)
-    lib.cbatched_igemm(
-        ptr,
-        ct.c_bool(transposed_B),
-        ct.c_bool(transposed_A),
-        ct.c_int32(m),
-        ct.c_int32(n),
-        ct.c_int32(k),
-        get_ptr(B),
-        get_ptr(A),
-        get_ptr(out),
-        ct.c_int32(lda),
-        ct.c_int32(ldb),
-        ct.c_int32(ldc),
-        ct.c_long(strideA),
-        ct.c_long(strideB),
-        ct.c_long(strideC),
-        ct.c_uint32(num_batch),
-    )
+    is_on_gpu([B, A, out])
+    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+                       get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
+                       ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
     return out
@@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     shapeB = SB[0]
     dimsA = len(shapeA)
     dimsB = len(shapeB)
+    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
     if dimsA == 2:
         m = shapeA[0]
     elif dimsA == 3:
         m = shapeA[0] * shapeA[1]
-    if dimsB == 2:
-        rows = n = shapeB[0]
-    elif dimsB == 3:
-        rows = n = shapeB[0] * shapeB[1]
+    rows = n = shapeB[0]
+    assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+    # if the tensor is empty, return a transformed empty tensor with the right dimensions
+    if shapeA[0] == 0 and dimsA == 2:
+        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+    elif shapeA[1] == 0 and dimsA == 3:
+        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
     if dimsA == 2 and out is None:
         out, Sout = get_transform_buffer(
@@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     has_error = 0
     ptrRowScale = get_ptr(None)
-    if formatB == "col_turing":
+    is_on_gpu([A, B, out])
+    if formatB == 'col_turing':
         if dtype == torch.int32:
             has_error = lib.cigemmlt_turing_32(
                 ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
@@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
             )
     if has_error == 1:
-        raise Exception("cublasLt ran into an error!")
+        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
+        raise Exception('cublasLt ran into an error!')
     torch.cuda.set_device(prev_device)
@@ -1457,16 +1439,8 @@ def mm_dequant(
     numRows = ct.c_int32(out_shape[0])
     numCols = ct.c_int32(out_shape[1])
-    lib.cdequant_mm_int32_fp16(
-        ptrA,
-        ptrRowStats,
-        ptrColStats,
-        ptrOut,
-        ptrNewRowStats,
-        ptrNewColStats,
-        numRows,
-        numCols,
-    )
+    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
+    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
    return out
@@ -1507,15 +1481,8 @@ def get_colrow_absmax(
     cols = ct.c_int32(cols)
     prev_device = pre_call(A.device)
-    lib.cget_col_row_stats(
-        ptrA,
-        ptrRowStats,
-        ptrColStats,
-        ptrNnzrows,
-        ct.c_float(threshold),
-        rows,
-        cols,
-    )
+    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
+    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
     post_call(prev_device)
     if threshold > 0.0:
@@ -1642,6 +1609,7 @@ def double_quant(
     ptrOutCol = get_ptr(out_col)
     ptrOutRow = get_ptr(out_row)
+    is_on_gpu([A, col_stats, row_stats, out_col, out_row])
     if threshold > 0.0:
         nnz = nnz_row_ptr[-1].item()
         if nnz > 0:
@@ -1714,33 +1682,19 @@ def get_special_format_str():
     )
     assert major >= 7
-    if major == 7:
-        return "col_turing"
-    elif major == 8:
-        return "col_ampere"
-    else:
-        return "col_turing"
+    if major == 7: return 'col_turing'
+    elif major == 8: return 'col_ampere'
+    else: return 'col_turing'
-def transform(
-    A,
-    to_order,
-    from_order="row",
-    out=None,
-    transpose=False,
-    state=None,
-    ld=None,
-):
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(
-            state[0], A.dtype, A.device, to_order, state[1], transpose
-        )
-    else:
-        new_state = (state[0], to_order)  # (shape, order)
+def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+    prev_device = pre_call(A.device)
+    if state is None: state = (A.shape, from_order)
+    else: from_order = state[1]
+    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+    else: new_state = (state[0], to_order) # (shape, order)
     shape = state[0]
     if len(shape) == 2:
@@ -1752,7 +1706,8 @@ def transform(
     ptrA = get_ptr(A)
     ptrOut = get_ptr(out)
-    if to_order == "col32":
+    is_on_gpu([A, out])
+    if to_order == 'col32':
         if transpose:
             lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
         else:
@@ -1773,9 +1728,9 @@ def transform(
     elif from_order == "col_ampere":
         lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
     else:
-        raise NotImplementedError(
-            f"Transform function not implemented: From {from_order} to {to_order}"
-        )
+        raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
+    post_call(prev_device)
     return out, new_state
@@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None):
     cldb = ct.c_int32(ldb)
     cldc = ct.c_int32(ldc)
-    lib.cspmm_coo(
-        ptr,
-        ptrRowidx,
-        ptrColidx,
-        ptrValues,
-        cnnz,
-        crowsA,
-        ccolsA,
-        ccolsB,
-        cldb,
-        ptrB,
-        cldc,
-        ptrC,
-        ct.c_bool(transposed_B),
-    )
+    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
+    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
     return out
@@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     # print(cooA.rowidx[:64])
     # print(cooA.colidx[:64].sort()[0])
+    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
     if B.dtype == torch.float16:
         lib.cspmm_coo_very_sparse_naive_fp16(
             ptrMaxCount,
@@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx):
     ptrIdx = get_ptr(idx)
     ptrOut = get_ptr(out)
-    if formatA == "col_turing":
+    prev_device = pre_call(A.device)
+    if formatA == 'col_turing':
         lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
     elif formatA == "col_ampere":
         lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+    post_call(prev_device)
     return out
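
Aside: several wrappers above now bracket the lib.* calls with prev_device = pre_call(A.device) and post_call(prev_device). Their implementation is not part of this diff; a rough sketch of the save-and-restore pattern they presumably follow:

import torch

def pre_call(device):
    # Assumed behavior: remember the currently active CUDA device and switch
    # to the device of the tensor about to be operated on.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device

def post_call(prev_device):
    # Restore whatever device was active before the library call.
    torch.cuda.set_device(prev_device)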
View File

@@ -19,53 +19,59 @@ using std::endl;
 void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
 {
   int threads = 512;
-  int blocks = n/threads;
-  blocks = n % threads == 0 ? blocks : blocks + 1;
-  kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
+  int num_blocks = n/threads;
+  num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
-  kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
+  kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 void quantize(float *code, float *A, unsigned char *out, int n)
 {
-  int blocks = n/1024;
-  blocks = n % 1024 == 0 ? blocks : blocks + 1;
-  kQuantize<<<blocks, 1024>>>(code, A, out, n);
+  int num_blocks = n/1024;
+  num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 void dequantize(float *code, unsigned char *A, float *out, int n)
 {
-  int blocks = n/1024;
-  blocks = n % 1024 == 0 ? blocks : blocks + 1;
-  kDequantize<<<blocks, 1024>>>(code, A, out, n);
+  int num_blocks = n/1024;
+  num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
-  kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
 {
-  int blocks = n/blocksize;
-  blocks = n % blocksize == 0 ? blocks : blocks + 1;
+  int num_blocks = n/blocksize;
+  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   if(blocksize == 4096)
-    kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
+    kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
   else if(blocksize == 2048)
-    kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
+    kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
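
Aside: the change repeated throughout this file is the same two-part pattern: compute the launch grid by ceiling division and assert that it stays under the classic 65535-blocks-per-grid-dimension limit. The arithmetic, restated in Python for illustration:

def grid_size(n, threads_per_block):
    # Ceiling division: one extra block whenever n is not an exact multiple.
    num_blocks = n // threads_per_block
    if n % threads_per_block != 0:
        num_blocks += 1
    # Mirrors the new assert(num_blocks <= 65535 && ...) guards above.
    assert num_blocks <= 65535, "CUDA ERROR: Maximum number of blocks for kernel exceeded"
    return num_blocks

print(grid_size(4097, 4096))  # -> 2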
@@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   switch(OPTIMIZER)
   {
     case ADAM:
       if(max_unorm > 0.0f)
       {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
-        kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
-        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
   }
@@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
                 float weight_decay,
                 const float gnorm_scale, int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
@@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case ADAM:
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
       CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
-      kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
+      kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                                                   quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
@@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case RMSPROP:
     case ADAGRAD:
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
-      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
                                                                   quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
@@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
 {
-  int blocks = 0;
+  int num_blocks = 0;
   switch(OPTIMIZER)
   {
     case ADAM:
-      blocks = n/BLOCKSIZE_2STATE;
-      blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
-      kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
+      num_blocks = n/BLOCKSIZE_2STATE;
+      num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
+      assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+      kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
                                                                                                quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
     case RMSPROP:
     case ADAGRAD:
-      blocks = n/BLOCKSIZE_1STATE;
-      blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
-      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
+      num_blocks = n/BLOCKSIZE_1STATE;
+      num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
+      assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
                                                                                                quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
@@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
 template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
 {
-  int blocks = n/2048;
-  blocks = n % 2048 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/2048;
+  num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
-  kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
+  kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
@@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
   int num_blocks = numRows/subtile_rows;
   num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
   num_blocks = num_blocks*(tileCols/32);
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   assert(threads <= tilesize);
-  //cout << num_blocks << " blocks" << endl;
   kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
@@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
   int tile_cols = STATS_THREADS*STATS_ITEMS;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
-  int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+  int row_tiles = (tiledRows/STATS_ROWS);
+  int col_tiles = (tiledCols/tile_cols);
+  row_tiles = row_tiles > 0 ? row_tiles : 1;
+  col_tiles = col_tiles > 0 ? col_tiles : 1;
+  int num_blocks = row_tiles * col_tiles;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   if(nnz_threshold == 0.0)
     kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
@@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
   int tile_rows = 16;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
-  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
-  //cout << cols << " " << tiledCols << " " << tiledRows << endl;
-  //cout << "num blocks " << num_blocks << endl;
-  //cout << A << " " << out_col_normed << endl;
+  int row_tiles = (tiledRows/tile_rows);
+  int col_tiles = (tiledCols/tile_cols);
+  row_tiles = row_tiles > 0 ? row_tiles : 1;
+  col_tiles = col_tiles > 0 ? col_tiles : 1;
+  int num_blocks = row_tiles * col_tiles;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   if(threshold > 0.0f)
     kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
   else
@@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
   int tile_rows = 32;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
-  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+  int row_tiles = (tiledRows/tile_rows);
+  int col_tiles = (tiledCols/tile_cols);
+  row_tiles = row_tiles > 0 ? row_tiles : 1;
+  col_tiles = col_tiles > 0 ? col_tiles : 1;
+  int num_blocks = row_tiles * col_tiles;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   int outCols = fill_up_to_nearest_multiple(cols, 32);
   int outRows = fill_up_to_nearest_multiple(rows, 32);
   if(FORMAT == COL_TURING)
@@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
     }
   }
-  //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
-  //cout << "num blocks " << num_blocks << endl;
-  //cout << A << " " << out_col_normed << endl;
   kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
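
Aside: the launchers above (getColRowStats, doubleRowColQuant, transformRowToFormat) now clamp the per-axis tile counts to at least 1, so a zero-row or zero-column input still yields a valid, non-empty launch grid. A Python restatement of that grid computation (fill_up_to_nearest_multiple is re-implemented here under the assumption that it rounds up to the next multiple):

def fill_up_to_nearest_multiple(value, multiple):
    # Assumed behavior of the CUDA helper: round value up to a multiple.
    return ((value + multiple - 1) // multiple) * multiple

def tiled_grid(rows, cols, tile_rows, tile_cols):
    tiled_rows = fill_up_to_nearest_multiple(rows, tile_rows)
    tiled_cols = fill_up_to_nearest_multiple(cols, tile_cols)
    row_tiles = max(tiled_rows // tile_rows, 1)  # clamp: an empty input still gets one tile
    col_tiles = max(tiled_cols // tile_cols, 1)
    num_blocks = row_tiles * col_tiles
    assert num_blocks <= 65535, "CUDA ERROR: Maximum number of blocks for kernel exceeded"
    return num_blocks

print(tiled_grid(0, 64, 16, 64 * 4))  # -> 1 rather than 0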
View File

@@ -40,7 +40,8 @@ names = [
     ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
-    dim2 = dim2 - (dim2 % 16)
+    if dim2 > 0:
+        dim2 = dim2 - (dim2 % 16)
     dim3 = dim3 - (dim3 % 16)
     dim4 = dim4 - (dim4 % 16)
     for i in range(k):
@@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
 dim3 = torch.randint(32, 96, size=(n,)).tolist()
 dim4 = torch.randint(32, 96, size=(n,)).tolist()
-# dim1 = (17,)
-# dim2 = (7,)
-# dim3 = (37,)
-# dim4 = (23,)
+dim2.append(0)
 decomp = [0.0, 6.0]
 funcs = [(torch.matmul, bnb.matmul)]
@@ -385,9 +383,14 @@ def test_matmullt(
             )
         if req_grad[1]:
             n = gradB1.numel()
-            assert torch.abs(gradB1).sum() > 0.0
-            assert torch.abs(gradB2).sum() > 0.0
+            if dim2 > 0:
+                assert torch.abs(gradB1).sum() > 0.0
+                assert torch.abs(gradB2).sum() > 0.0
+            else:
+                assert torch.abs(gradB1).sum() == 0.0
+                assert torch.abs(gradB2).sum() == 0.0
             idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
             assert (idx == 0).sum().item() < n * 0.1
             idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
             assert (idx == 0).sum().item() < n * 0.02
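
Aside: the test now appends 0 to dim2, so some parameterizations exercise genuinely empty inputs, and in those cases the B gradients must be exactly zero rather than merely small. A compressed sketch of what such a case looks like with plain PyTorch (the dimension values are illustrative):

import torch

dim1, dim2, dim4 = 32, 0, 48            # dim2 == 0 -> empty activations
A = torch.randn(dim1, dim2, requires_grad=True)
B = torch.randn(dim2, dim4, requires_grad=True)
out = torch.matmul(A, B)                 # shape (32, 48), built from zero elements
out.sum().backward()
assert torch.abs(B.grad).sum() == 0.0    # matches the new dim2 == 0 branch in the test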