forked from mrq/bitsandbytes-rocm
Added fixes for the case that matmullt dim A is zero, e.g. [0, 768].
This commit is contained in:
parent
2f01865a2f
commit
451fd9506e
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import math
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import bitsandbytes.functional as F
|
import bitsandbytes.functional as F
|
||||||
|
|
||||||
|
@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, A, B, out=None, state=MatmulLtState()):
|
def forward(ctx, A, B, out=None, state=MatmulLtState()):
|
||||||
|
# default to pytorch behavior if inputs are empty
|
||||||
|
ctx.is_empty = False
|
||||||
|
if math.prod(A.shape) == 0:
|
||||||
|
ctx.is_empty = True
|
||||||
|
ctx.A = A
|
||||||
|
ctx.B = B
|
||||||
|
if A.shape[-1] == B.shape[0]:
|
||||||
|
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
|
||||||
|
else:
|
||||||
|
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
|
||||||
|
|
||||||
# 1. Quantize A
|
# 1. Quantize A
|
||||||
# 2. Quantize B
|
# 2. Quantize B
|
||||||
# 3. Matmul
|
# 3. Matmul
|
||||||
|
@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
|
if ctx.is_empty:
|
||||||
|
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
|
||||||
req_gradA, req_gradB = ctx.req_grads
|
req_gradA, req_gradB = ctx.req_grads
|
||||||
CAt, subA = ctx.tensors
|
CAt, subA = ctx.tensors
|
||||||
SCAt, idx = ctx.tensor_states
|
SCAt, idx = ctx.tensor_states
|
||||||
|
@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||||
|
|
||||||
return grad_A, grad_B, None, None, None, None, None
|
return grad_A, grad_B, None, None
|
||||||
|
|
||||||
|
|
||||||
matmul = MatMul8bitLt.apply
|
matmul = MatMul8bitLt.apply
|
||||||
|
|
|
@ -4,9 +4,10 @@
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
import ctypes as ct
|
import ctypes as ct
|
||||||
import random
|
import random
|
||||||
from typing import Tuple
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from .cextension import lib, COMPILED_WITH_CUDA
|
from .cextension import lib, COMPILED_WITH_CUDA
|
||||||
|
@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
||||||
shapeB = SB[0]
|
shapeB = SB[0]
|
||||||
dimsA = len(shapeA)
|
dimsA = len(shapeA)
|
||||||
dimsB = len(shapeB)
|
dimsB = len(shapeB)
|
||||||
|
assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
|
||||||
if dimsA == 2:
|
if dimsA == 2:
|
||||||
m = shapeA[0]
|
m = shapeA[0]
|
||||||
elif dimsA == 3:
|
elif dimsA == 3:
|
||||||
m = shapeA[0]*shapeA[1]
|
m = shapeA[0]*shapeA[1]
|
||||||
|
|
||||||
if dimsB == 2:
|
rows = n = shapeB[0]
|
||||||
rows = n = shapeB[0]
|
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
|
||||||
elif dimsB == 3:
|
print(shapeA, math.prod(shapeA), math.prod(list(shapeA)))
|
||||||
rows = n = shapeB[0]*shapeB[1]
|
print('aaa')
|
||||||
|
|
||||||
|
# if the tensor is empty, return a transformed empty tensor with the right dimensions
|
||||||
|
if shapeA[0] == 0 and dimsA == 2:
|
||||||
|
return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
|
||||||
|
elif shapeA[1] == 0 and dimsA == 3:
|
||||||
|
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
|
||||||
|
|
||||||
if dimsA == 2 and out is None:
|
if dimsA == 2 and out is None:
|
||||||
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
|
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
|
||||||
|
@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
||||||
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
||||||
|
|
||||||
if has_error == 1:
|
if has_error == 1:
|
||||||
|
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
|
||||||
raise Exception('cublasLt ran into an error!')
|
raise Exception('cublasLt ran into an error!')
|
||||||
|
|
||||||
torch.cuda.set_device(prev_device)
|
torch.cuda.set_device(prev_device)
|
||||||
|
|
31
csrc/ops.cu
31
csrc/ops.cu
|
@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
|
||||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||||
assert(threads <= tilesize);
|
assert(threads <= tilesize);
|
||||||
|
|
||||||
//cout << num_blocks << " blocks" << endl;
|
|
||||||
|
|
||||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
|
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
|
||||||
int tile_cols = STATS_THREADS*STATS_ITEMS;
|
int tile_cols = STATS_THREADS*STATS_ITEMS;
|
||||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||||
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
|
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
|
||||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
|
int row_tiles = (tiledRows/STATS_ROWS);
|
||||||
|
int col_tiles = (tiledCols/tile_cols);
|
||||||
|
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||||
|
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||||
|
int num_blocks = row_tiles * col_tiles;
|
||||||
|
|
||||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||||
|
|
||||||
|
|
||||||
if(nnz_threshold == 0.0)
|
if(nnz_threshold == 0.0)
|
||||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||||
else if(nnz_threshold != 0.0)
|
else if(nnz_threshold != 0.0)
|
||||||
|
@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
|
||||||
int tile_rows = 16;
|
int tile_rows = 16;
|
||||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
int row_tiles = (tiledRows/tile_rows);
|
||||||
|
int col_tiles = (tiledCols/tile_cols);
|
||||||
|
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||||
|
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||||
|
int num_blocks = row_tiles * col_tiles;
|
||||||
|
|
||||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||||
|
|
||||||
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
|
|
||||||
//cout << "num blocks " << num_blocks << endl;
|
|
||||||
|
|
||||||
//cout << A << " " << out_col_normed << endl;
|
|
||||||
if(threshold > 0.0f)
|
if(threshold > 0.0f)
|
||||||
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
||||||
else
|
else
|
||||||
|
@ -518,7 +520,12 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
||||||
int tile_rows = 32;
|
int tile_rows = 32;
|
||||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
int row_tiles = (tiledRows/tile_rows);
|
||||||
|
int col_tiles = (tiledCols/tile_cols);
|
||||||
|
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||||
|
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||||
|
int num_blocks = row_tiles * col_tiles;
|
||||||
|
|
||||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||||
int outCols = fill_up_to_nearest_multiple(cols, 32);
|
int outCols = fill_up_to_nearest_multiple(cols, 32);
|
||||||
int outRows = fill_up_to_nearest_multiple(rows, 32);
|
int outRows = fill_up_to_nearest_multiple(rows, 32);
|
||||||
|
@ -545,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
|
|
||||||
//cout << "num blocks " << num_blocks << endl;
|
|
||||||
|
|
||||||
//cout << A << " " << out_col_normed << endl;
|
|
||||||
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
|
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
|
||||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st
|
||||||
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
|
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
|
||||||
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
|
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
|
||||||
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||||
dim2 = dim2 - (dim2 % 16)
|
if dim2 > 0:
|
||||||
|
dim2 = dim2 - (dim2 % 16)
|
||||||
dim3 = dim3 - (dim3 % 16)
|
dim3 = dim3 - (dim3 % 16)
|
||||||
dim4 = dim4 - (dim4 % 16)
|
dim4 = dim4 - (dim4 % 16)
|
||||||
for i in range(k):
|
for i in range(k):
|
||||||
|
@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist()
|
||||||
dim3 = torch.randint(32,96, size=(n,)).tolist()
|
dim3 = torch.randint(32,96, size=(n,)).tolist()
|
||||||
dim4 = torch.randint(32,96, size=(n,)).tolist()
|
dim4 = torch.randint(32,96, size=(n,)).tolist()
|
||||||
|
|
||||||
|
dim2.append(0)
|
||||||
#dim1 = (17,)
|
#dim1 = (17,)
|
||||||
#dim2 = (7,)
|
#dim2 = (7,)
|
||||||
#dim3 = (37,)
|
#dim3 = (37,)
|
||||||
|
@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
|
||||||
err = torch.abs(out_bnb-out_torch).mean().item()
|
err = torch.abs(out_bnb-out_torch).mean().item()
|
||||||
#print(f'abs error {err:.4f}')
|
#print(f'abs error {err:.4f}')
|
||||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||||
assert (idx==0).sum().item() < n*0.0175
|
assert (idx==0).sum().item() <= n*0.0175
|
||||||
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
|
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
|
||||||
assert (idx==0).sum().item() < n*0.001
|
assert (idx==0).sum().item() <= n*0.001
|
||||||
|
|
||||||
if has_fp16_weights:
|
if has_fp16_weights:
|
||||||
if any(req_grad):
|
if any(req_grad):
|
||||||
|
@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
|
||||||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||||
if req_grad[1]:
|
if req_grad[1]:
|
||||||
n = gradB1.numel()
|
n = gradB1.numel()
|
||||||
assert torch.abs(gradB1).sum() > 0.0
|
if dim2 > 0:
|
||||||
assert torch.abs(gradB2).sum() > 0.0
|
assert torch.abs(gradB1).sum() > 0.0
|
||||||
|
assert torch.abs(gradB2).sum() > 0.0
|
||||||
|
else:
|
||||||
|
assert torch.abs(gradB1).sum() == 0.0
|
||||||
|
assert torch.abs(gradB2).sum() == 0.0
|
||||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||||
assert (idx==0).sum().item() < n*0.1
|
assert (idx==0).sum().item() <= n*0.1
|
||||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||||
assert (idx==0).sum().item() < n*0.02
|
assert (idx==0).sum().item() <= n*0.02
|
||||||
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
|
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user