From 451fd9506e215aa25643e9782cb7d8aed2a266cc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 11:54:01 -0700 Subject: [PATCH] Added fixes for the case that matmullt dim A is zero, e.g. [0, 768]. --- bitsandbytes/autograd/_functions.py | 16 ++++++++++++++- bitsandbytes/functional.py | 21 +++++++++++++------ csrc/ops.cu | 31 ++++++++++++++++------------- tests/test_autograd.py | 20 ++++++++++++------- 4 files changed, 60 insertions(+), 28 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 815a4f1..370ca83 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,4 +1,5 @@ import torch +import math import bitsandbytes as bnb import bitsandbytes.functional as F @@ -162,6 +163,17 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if math.prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + # 1. Quantize A # 2. Quantize B # 3. Matmul @@ -265,6 +277,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None req_gradA, req_gradB = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states @@ -293,7 +307,7 @@ class MatMul8bitLt(torch.autograd.Function): gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) - return grad_A, grad_B, None, None, None, None, None + return grad_A, grad_B, None, None matmul = MatMul8bitLt.apply diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0a2d557..494de1b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. import ctypes as ct import random -from typing import Tuple - +import math import torch + +from typing import Tuple from torch import Tensor from .cextension import lib, COMPILED_WITH_CUDA @@ -919,15 +920,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0]*shapeA[1] - if dimsB == 2: - rows = n = shapeB[0] - elif dimsB == 3: - rows = n = shapeB[0]*shapeB[1] + rows = n = shapeB[0] + assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + print(shapeA, math.prod(shapeA), math.prod(list(shapeA))) + print('aaa') + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') @@ -984,6 +992,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') raise Exception('cublasLt ran into an error!') torch.cuda.set_device(prev_device) diff --git a/csrc/ops.cu b/csrc/ops.cu index b3d07c6..cfc9605 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); assert(threads <= tilesize); - //cout << num_blocks << " blocks" << endl; - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r int tile_cols = STATS_THREADS*STATS_ITEMS; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - if(nnz_threshold == 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) @@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col int tile_rows = 16; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - //cout << cols << " " << tiledCols << " " << tiledRows << endl; - //cout << "num blocks " << num_blocks << endl; - - //cout << A << " " << out_col_normed << endl; if(threshold > 0.0f) kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); else @@ -518,7 +520,12 @@ template void transformRowToFormat(char * A, char *o int tile_rows = 32; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); @@ -545,10 +552,6 @@ template void transformRowToFormat(char * A, char *o } } - //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl; - //cout << "num blocks " << num_blocks << endl; - - //cout << A << " " << out_col_normed << endl; kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } diff --git a/tests/test_autograd.py b/tests/test_autograd.py index d2b5d59..1b6c2ab 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values] @pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): - dim2 = dim2 - (dim2 % 16) + if dim2 > 0: + dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) for i in range(k): @@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist() dim3 = torch.randint(32,96, size=(n,)).tolist() dim4 = torch.randint(32,96, size=(n,)).tolist() +dim2.append(0) #dim1 = (17,) #dim2 = (7,) #dim3 = (37,) @@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec err = torch.abs(out_bnb-out_torch).mean().item() #print(f'abs error {err:.4f}') idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx==0).sum().item() < n*0.0175 + assert (idx==0).sum().item() <= n*0.0175 idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx==0).sum().item() < n*0.001 + assert (idx==0).sum().item() <= n*0.001 if has_fp16_weights: if any(req_grad): @@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() - assert torch.abs(gradB1).sum() > 0.0 - assert torch.abs(gradB2).sum() > 0.0 + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx==0).sum().item() < n*0.1 + assert (idx==0).sum().item() <= n*0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx==0).sum().item() < n*0.02 + assert (idx==0).sum().item() <= n*0.02 torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)