From ca3236587ad285b8a43a96629516d3362045bb99 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Mon, 13 Feb 2023 17:20:52 -0800
Subject: [PATCH] Added forward/backward tests; removed bias.

---
 bitsandbytes/autograd/_functions.py | 36 +++++++----------
 bitsandbytes/nn/modules.py          |  4 +-
 tests/test_autograd.py              | 61 +++++++++++++++--------------
 3 files changed, 48 insertions(+), 53 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index fc027f2..c2b8773 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -395,15 +395,14 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
-            ctx.bias = bias
-            B_shape = state[1]
+            B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
@@ -414,17 +413,17 @@ class MatMulFP8(torch.autograd.Function):
 
         # 2. MatmulnN
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state)
+        fp8A = F.dequantize_blockwise(cA, state).to(A.dtype)
 
         cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state)
+        fp8B = F.dequantize_blockwise(cB, state).to(B.dtype)
 
-        output = torch.nn.functional.linear(fp8A, fp8B)
+        output = torch.matmul(fp8A, fp8B)
 
         # 3. Save state
         ctx.bw_code = bw_code
-        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
 
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (fp8A, fp8B)
@@ -436,21 +435,15 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None
 
-        req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
+        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
         fp8A, B = ctx.tensors
-        state = ctx.state
 
-        grad_A, grad_B, grad_bias = None, None, None
+        grad_A, grad_B = None, None
 
-        cgrad_out, state = F.quantize_blockwise(grad_ouput, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state)
-
-        if req_gradBias:
-            # compute grad_bias first before changing grad_output dtype
-            grad_bias = fp8out.sum(0, dtype=ctx.dtype_bias)
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
+        fp8out = F.dequantize_blockwise(cgrad_out, state).to(grad_output.dtype)
 
         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
@@ -461,7 +454,7 @@ class MatMulFP8(torch.autograd.Function):
         if req_gradA: grad_A = torch.matmul(fp8out, B.t())
         if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
 
-        return grad_A, grad_B, None, grad_bias, None, None
+        return grad_A, grad_B, None, None, None
 
 
 def matmul(
@@ -478,9 +471,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bias=None):
-    assert quant_state is not None
-    return MatMulFP8.apply(A, B, out, bias, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code)
 
 
 def matmul(
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b1d5355..5e12ddb 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -355,7 +355,9 @@ class LinearFP8(nn.Linear):
             self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
 
-        out = bnb.matmul_fp8(x, self.weight.t(), bias=self.bias, fw_code=self.fw_code, code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
+        if self.bias is not None:
+            out += self.bias
 
         return out
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 0def35d..4d3e67a 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -456,18 +456,16 @@ transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16, torch.float32]
 has_fp16_weights = [True, False]
-has_bias = [True, False]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
-str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
-names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
-def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    if has_bias == False:
-        req_grad = list(req_grad)
-        req_grad[2] = False
+    req_grad = list(req_grad)
+    req_grad[2] = False
     for i in range(k):
         # normal multiply
@@ -475,32 +473,24 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
         B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
         target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
-        bias = None
-        bias2 = None
-        if has_bias:
-            bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
-            bias2 = bias.clone()
         torch.nn.init.xavier_uniform_(B)
-        B2, quant_state = bnb.functional.quantize_fp8(B)
+        fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
+        bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
 
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
-
-        if has_bias:
-            out_torch += bias
+            out_bnb = funcs[1](A, B, fw_code, bw_code)
 
         assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
 
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).float().mean().item()
         if n > 0:
-            assert err < 0.115
-
+            assert err < 0.20
 
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
             torch.cuda.synchronize()
@@ -510,9 +500,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             gradB1 = B.grad
             A.grad = None
             B.grad = None
-            if has_bias:
-                gradBias1 = bias.grad
-                bias.grad = None
 
             loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
             loss_torch.backward()
@@ -520,12 +507,26 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             gradB2 = B.grad
             A.grad = None
             B.grad = None
-            if has_bias:
-                gradBias2 = bias.grad
-                bias.grad = None
 
             if req_grad[0]:
                 torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1)
-            if req_grad[2]:
-                torch.testing.assert_allclose(gradBias1, gradBias2)
+            if req_grad[1]:
+                n = gradB1.numel()
+                if dim2 > 0:
+                    assert torch.abs(gradB1).sum() > 0.0
+                    assert torch.abs(gradB2).sum() > 0.0
+                else:
+                    assert torch.abs(gradB1).sum() == 0.0
+                    assert torch.abs(gradB2).sum() == 0.0
+                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+
+                assert (idx == 0).sum().item() <= n * 0.1
+                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                assert (idx == 0).sum().item() <= n * 0.02
+                grad_err = (gradB1-gradB2).abs().mean()
+                assert grad_err.item() < 0.003
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )
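
Usage sketch (not part of the commit above): a minimal example of the new bias-free matmul_fp8 path, mirroring what LinearFP8.forward and the updated test do. The tensor shapes and the .mean() loss are illustrative assumptions; the create_fp8_map arguments are the same E4M3 (forward) / E5M2 (backward) maps used in the patch.

    # Minimal sketch of the bias-free FP8 matmul path; assumes this patch is
    # applied and a CUDA build of bitsandbytes is available.
    import torch
    import bitsandbytes as bnb

    A = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    B = torch.randn(64, 16, device="cuda", dtype=torch.float16, requires_grad=True)

    # Same quantization maps that LinearFP8.forward builds: E4M3 for the
    # forward pass, E5M2 for the backward pass.
    fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
    bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

    out = bnb.matmul_fp8(A, B, fw_code, bw_code)  # bias is no longer an argument
    out.mean().backward()                         # gradients flow to A and B only

Any bias term is now added outside the autograd function (as LinearFP8.forward does after calling matmul_fp8), so the FP8 quantization only ever sees the matmul operands and the incoming gradient.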