From 9e7cdc9ea95e9756d9f5621a0e2c7e2538363fae Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Wed, 12 Apr 2023 13:41:30 -0700
Subject: [PATCH] Added last SwitchBack refactors. All tests green.

---
 CHANGELOG.md                            |  7 +++++++
 bitsandbytes/nn/__init__.py             |  2 +-
 bitsandbytes/nn/triton_based_modules.py | 18 +++++++++---------
 setup.py                                |  2 +-
 tests/test_triton.py                    | 16 ++++++++--------
 5 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5399c02..2de70d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -221,3 +221,10 @@ Improvements:
 Deprecated:
  - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
  - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
+
+
+### 0.38.1
+
+Features:
+ - Added Int8 SwitchBack layers
+ - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index ec944a3..f51f600 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -3,4 +3,4 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb
-from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
+from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py
index 7794fa0..6fbf583 100644
--- a/bitsandbytes/nn/triton_based_modules.py
+++ b/bitsandbytes/nn/triton_based_modules.py
@@ -157,7 +157,7 @@ class SwitchBackLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
-        vectorize: bool = False,
+        vector_wise_quantization: bool = False,
         mem_efficient : bool = False,
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
@@ -167,11 +167,11 @@ class SwitchBackLinear(nn.Linear):
                               Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
 
         # By default, we use the global quantization.
-        self.vectorize = vectorize
-        if self.vectorize:
+        self.vector_wise_quantization = vector_wise_quantization
+        if self.vector_wise_quantization:
             self._fn = _switchback_vectorrize
             if mem_efficient:
-                print('mem efficient is not supported for vectorize mode.')
+                print('mem efficient is not supported for vector-wise quantization.')
                 exit(1)
         else:
             if mem_efficient:
@@ -188,7 +188,7 @@ class SwitchBackLinear(nn.Linear):
         #     m.prepare_for_eval()
         # model.apply(cond_prepare)
         print('=> preparing for eval.')
-        if self.vectorize:
+        if self.vector_wise_quantization:
             W_int8, state_W = quantize_rowwise(self.weight)
         else:
             W_int8, state_W = quantize_global(self.weight)
@@ -210,7 +210,7 @@ class SwitchBackLinear(nn.Linear):
         X = x.view(-1, x.size(-1))
         X_int8, state_X = quantize_rowwise(X)
 
-        if self.vectorize:
+        if self.vector_wise_quantization:
             return int8_matmul_rowwise_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
@@ -219,9 +219,9 @@ class SwitchBackLinear(nn.Linear):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
 
-SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
-SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
-SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
+SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
+SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
 
 # This is just the standard linear function.
 class StandardLinearFunction(torch.autograd.Function):
diff --git a/setup.py b/setup.py
index e514463..009fd3d 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ def read(fname):
 
 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0.post2",
+    version=f"0.38.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
diff --git a/tests/test_triton.py b/tests/test_triton.py
index 7f56a49..e18c7a9 100644
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -1,19 +1,19 @@
 import pytest
 import torch
 
+from bitsandbytes.triton.triton_utils import is_triton_available
 from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
 from bitsandbytes.nn import Linear8bitLt
 
-
-@pytest.mark.skipif(not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires a GPU with compute capability 8.0 or higher.")
-@pytest.mark.parametrize("vectorrize", [False, True])
-def test_switchback(vectorrize):
-    for dim in [83, 17, 128]:
-        for batch in [13, 128, 256]:
+@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+def test_switchback(vector_wise_quantization):
+    for dim in [83]:
+        for batch in [13]:
 
             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            print('vectorrize', vectorrize)
-            switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
+            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)
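
Usage sketch, not part of the patch above: a minimal example of the renamed API, assuming triton is installed and a CUDA GPU with compute capability 8.0 or higher is available, as the updated test requires. It uses only names exported by the patched bitsandbytes/nn/__init__.py, and the dim/batch values mirror the updated test.

    import torch
    from bitsandbytes.nn import SwitchBackLinear, SwitchBackLinearVectorwise

    dim, batch = 83, 13

    # Default is global quantization of the weight; pass vector_wise_quantization=True
    # (or use the SwitchBackLinearVectorwise partial) to quantize row-wise instead.
    linear_global = SwitchBackLinear(dim, 4 * dim).cuda().half()
    linear_vector = SwitchBackLinearVectorwise(dim, 4 * dim).cuda().half()

    x = torch.randn(batch, dim, device="cuda", dtype=torch.half)
    out = linear_vector(x)  # int8 matmul with row-wise dequantize on the forward pass
    print(out.shape)        # torch.Size([13, 332])

Note that the keyword rename (vectorize to vector_wise_quantization) and the alias rename (SwitchBackLinearVectorized to SwitchBackLinearVectorwise) break callers of the previous 0.38.0 API, hence the matching updates to the nn/__init__.py export list and the test in the same commit.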