From 160a83580d3e159d00fa3004c8b98a64d08fb732 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 4 Feb 2023 21:11:21 -0800 Subject: [PATCH] Forward matmul_fp4 tests pass. --- bitsandbytes/__init__.py | 1 + bitsandbytes/autograd/_functions.py | 67 +++++++++++++++- bitsandbytes/functional.py | 15 ++-- bitsandbytes/nn/modules.py | 62 +++++++++++++++ tests/test_autograd.py | 115 ++++++++++++++++++++++++++++ tests/test_functional.py | 17 +--- 6 files changed, 254 insertions(+), 23 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 041df4b..c83b7ff 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -10,6 +10,7 @@ from .autograd._functions import ( matmul, matmul_cublas, mm_cublas, + matmul_fp4 ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 376fb8a..a098d4b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -2,7 +2,7 @@ import operator import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional +from typing import Tuple, Optional, List import torch @@ -474,6 +474,67 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None +class MatMulFP4(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=None): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + B_shape = state[1] + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + + # 1. Dequantize + # 2. Matmul + output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias) + + # 3. 
Save state
+        ctx.state = state
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            # keep both the input and the packed FP4 weight around for backward
+            ctx.tensors = (A, B)
+        else:
+            ctx.tensors = (None, None)
+            ctx.tensor_states = (None, None)
+            ctx.save_for_backward(None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
+        A, B = ctx.tensors
+        state = ctx.state
+        grad_A, grad_B, grad_bias = None, None, None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
+
+        # fold the leading dimensions so grad_output is 2D for the matmuls below
+        if len(grad_output.shape) == 3:
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+
+        if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, state).to(ctx.dtype_A))
+
+        return grad_A, grad_B, None, grad_bias, None
+
+
 def matmul(
     A: tensor,
     B: tensor,
@@ -486,3 +547,7 @@ def matmul(
     if threshold > 0.0:
         state.threshold = threshold
     return MatMul8bitLt.apply(A, B, out, bias, state)
+
+
+def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+    return MatMulFP4.apply(A, B, out, bias, quant_state)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index da9e743..92ac670 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -626,7 +626,7 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
     -------
     torch.Tensor:
         The 8-bit tensor with packed 4-bit values.
-    tuple(torch.Tensor, torch.Size, torch.dtype):
+    tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
""" if A.device.type != 'cuda': @@ -640,10 +640,10 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - state = (absmax, input_shape, A.dtype) + state = (absmax, input_shape, A.dtype, blocksize) if out is None: - out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device) + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -692,7 +692,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: shape = out.shape dtype = out.dtype else: - absmax, shape, dtype = quant_state + absmax, shape, dtype, blocksize = quant_state if out is None: @@ -700,6 +700,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: n = out.numel() + device = pre_call(A.device) is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: @@ -710,9 +711,9 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - return out - - + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 45df35e..6dfb06c 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -133,6 +133,67 @@ class Embedding(torch.nn.Embedding): return emb +class FP4Params(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, quant_state=None): + cls.quant_state = None + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) + w_fp4, quant_state = bnb.functional.quantize_fp4(w) + self.data = w_fp4 + self.quant_state = quant_state + + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) + else: + new_param = FP4Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state) + + return new_param + + +class LinearFP4(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.state = bnb.MatmulLtState() + self.weight = FP4Params(self.weight.data, requires_grad=False) + + def init_8bit_state(self): + pass + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, 'state', None) is None: + print('FP4 state not initialized. 
+        out = bnb.matmul_fp4(x, self.weight, bias=self.bias, quant_state=self.weight.quant_state)
+
+        return out
+
 class Int8Params(torch.nn.Parameter):
     def __new__(
@@ -208,6 +269,7 @@ class Int8Params(torch.nn.Parameter):
         return new_param
 
 
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                  memory_efficient_backward=False, threshold=0.0, index=None):
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index c67126d..ba75d76 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -429,3 +429,118 @@ def test_matmullt(
     if req_grad[2]:
         torch.testing.assert_allclose(gradBias1, gradBias2)
+
+
+n = 1
+k = 3
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
+
+dim2.append(0)
+
+funcs = [(torch.matmul, bnb.matmul_fp4)]
+str_funcs = ["matmul_fp4"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True: strval += 'T'
+        else: strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+has_fp16_weights = [True, False]
+has_bias = [True, False]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
+def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            bias = None
+            bias2 = None
+            if has_bias:
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
+            torch.nn.init.xavier_uniform_(B)
+            B2 = B.clone()
+
+            B2, quant_state = bnb.functional.quantize_fp4(B)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmul_fp4 received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.11
+
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+ gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None + + if req_grad[0]: + torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + + assert (idx == 0).sum().item() <= n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.02 + torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3 + ) + + if req_grad[2]: + torch.testing.assert_allclose(gradBias1, gradBias2) diff --git a/tests/test_functional.py b/tests/test_functional.py index efdda54..e6b7b81 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2221,26 +2221,13 @@ def test_fp4_quant(): A1 = torch.randn(1024, 1024, device='cuda').half() qa, SA = F.quantize_fp4(A1, blocksize=64) A2 = F.dequantize_fp4(qa, SA) - #qa, SA = F.quantize_fp4(A1, blocksize=128) - #A2 = F.dequantize_fp4(qa, SA, blocksize=128) - - #A1 = A1.flatten().sort()[0] - #A2 = A2.flatten().sort()[0] - - #print(A1) - #print(A2) err = (A1 - A2).abs().float() relerr = (err/A1.abs().float()).mean() err = err.mean() - print(err, relerr) - - - - - #assert err.item() < 0.1 - #assert relerr.item() < 0.28 + assert err.item() < 0.1 + assert relerr.item() < 0.28 def test_bench_fp4_dequant():
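
Reviewer note (not part of the patch): below is a minimal usage sketch of the FP4 path this patch adds, mirroring the non-transposed ("NT") case exercised by test_matmul_fp4 above. It assumes a CUDA device is available; the shapes and the expected error magnitude are illustrative only and are not asserted anywhere in the patch.

import torch
import bitsandbytes as bnb

# Pack a half-precision weight matrix into 4-bit FP4 storage; quant_state
# carries (absmax, shape, dtype, blocksize) and is required to dequantize.
W = torch.randn(32, 64, device="cuda", dtype=torch.float16)   # (out_features, in_features)
W_fp4, quant_state = bnb.functional.quantize_fp4(W)

# matmul_fp4 dequantizes W on the fly and applies it like torch.nn.functional.linear,
# so the result should track the dense reference up to FP4 quantization error.
x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
out_ref = torch.nn.functional.linear(x, W)
out_fp4 = bnb.matmul_fp4(x, W_fp4, quant_state=quant_state)

print((out_ref - out_fp4).abs().float().mean())  # small mean absolute error expected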