Added fixes for the case that matmullt dim A is zero, e.g. [0, 768].
This commit is contained in:
parent
2f01865a2f
commit
451fd9506e
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import math
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
|
@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, state=MatmulLtState()):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if math.prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
if A.shape[-1] == B.shape[0]:
|
||||
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
|
||||
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
|
@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
|
||||
req_gradA, req_gradB = ctx.req_grads
|
||||
CAt, subA = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
|
@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
return grad_A, grad_B, None, None
|
||||
|
||||
|
||||
matmul = MatMul8bitLt.apply
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
import ctypes as ct
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
from torch import Tensor
|
||||
|
||||
from .cextension import lib, COMPILED_WITH_CUDA
|
||||
|
@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
|||
shapeB = SB[0]
|
||||
dimsA = len(shapeA)
|
||||
dimsB = len(shapeB)
|
||||
assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
|
||||
if dimsA == 2:
|
||||
m = shapeA[0]
|
||||
elif dimsA == 3:
|
||||
m = shapeA[0]*shapeA[1]
|
||||
|
||||
if dimsB == 2:
|
||||
rows = n = shapeB[0]
|
||||
elif dimsB == 3:
|
||||
rows = n = shapeB[0]*shapeB[1]
|
||||
rows = n = shapeB[0]
|
||||
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
|
||||
print(shapeA, math.prod(shapeA), math.prod(list(shapeA)))
|
||||
print('aaa')
|
||||
|
||||
# if the tensor is empty, return a transformed empty tensor with the right dimensions
|
||||
if shapeA[0] == 0 and dimsA == 2:
|
||||
return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
|
||||
elif shapeA[1] == 0 and dimsA == 3:
|
||||
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)
|
||||
|
||||
if dimsA == 2 and out is None:
|
||||
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
|
||||
|
@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
|||
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
|
||||
|
||||
if has_error == 1:
|
||||
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
|
||||
raise Exception('cublasLt ran into an error!')
|
||||
|
||||
torch.cuda.set_device(prev_device)
|
||||
|
|
31
csrc/ops.cu
31
csrc/ops.cu
|
@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
|
|||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
assert(threads <= tilesize);
|
||||
|
||||
//cout << num_blocks << " blocks" << endl;
|
||||
|
||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
|
|||
int tile_cols = STATS_THREADS*STATS_ITEMS;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
|
||||
int row_tiles = (tiledRows/STATS_ROWS);
|
||||
int col_tiles = (tiledCols/tile_cols);
|
||||
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
|
||||
|
||||
if(nnz_threshold == 0.0)
|
||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
else if(nnz_threshold != 0.0)
|
||||
|
@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
|
|||
int tile_rows = 16;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
||||
int row_tiles = (tiledRows/tile_rows);
|
||||
int col_tiles = (tiledCols/tile_cols);
|
||||
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
|
||||
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
|
||||
//cout << "num blocks " << num_blocks << endl;
|
||||
|
||||
//cout << A << " " << out_col_normed << endl;
|
||||
if(threshold > 0.0f)
|
||||
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
||||
else
|
||||
|
@ -518,7 +520,12 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
|||
int tile_rows = 32;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
||||
int row_tiles = (tiledRows/tile_rows);
|
||||
int col_tiles = (tiledCols/tile_cols);
|
||||
row_tiles = row_tiles > 0 ? row_tiles : 1;
|
||||
col_tiles = col_tiles > 0 ? col_tiles : 1;
|
||||
int num_blocks = row_tiles * col_tiles;
|
||||
|
||||
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
|
||||
int outCols = fill_up_to_nearest_multiple(cols, 32);
|
||||
int outRows = fill_up_to_nearest_multiple(rows, 32);
|
||||
|
@ -545,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
|||
}
|
||||
}
|
||||
|
||||
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
|
||||
//cout << "num blocks " << num_blocks << endl;
|
||||
|
||||
//cout << A << " " << out_col_normed << endl;
|
||||
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
|
|
@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st
|
|||
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
|
||||
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
|
||||
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||
dim2 = dim2 - (dim2 % 16)
|
||||
if dim2 > 0:
|
||||
dim2 = dim2 - (dim2 % 16)
|
||||
dim3 = dim3 - (dim3 % 16)
|
||||
dim4 = dim4 - (dim4 % 16)
|
||||
for i in range(k):
|
||||
|
@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist()
|
|||
dim3 = torch.randint(32,96, size=(n,)).tolist()
|
||||
dim4 = torch.randint(32,96, size=(n,)).tolist()
|
||||
|
||||
dim2.append(0)
|
||||
#dim1 = (17,)
|
||||
#dim2 = (7,)
|
||||
#dim3 = (37,)
|
||||
|
@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
|
|||
err = torch.abs(out_bnb-out_torch).mean().item()
|
||||
#print(f'abs error {err:.4f}')
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx==0).sum().item() < n*0.0175
|
||||
assert (idx==0).sum().item() <= n*0.0175
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
|
||||
assert (idx==0).sum().item() < n*0.001
|
||||
assert (idx==0).sum().item() <= n*0.001
|
||||
|
||||
if has_fp16_weights:
|
||||
if any(req_grad):
|
||||
|
@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
|
|||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
if dim2 > 0:
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
else:
|
||||
assert torch.abs(gradB1).sum() == 0.0
|
||||
assert torch.abs(gradB2).sum() == 0.0
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.1
|
||||
assert (idx==0).sum().item() <= n*0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.02
|
||||
assert (idx==0).sum().item() <= n*0.02
|
||||
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user