Added last SwitchBack refactors. All tests green.
This commit is contained in:
parent 008dfff9b4
commit 9e7cdc9ea9

@@ -221,3 +221,10 @@ Improvements:
 Deprecated:
 - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
 - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
+
+### 0.38.1
+
+Features:
+
+- Added Int8 SwitchBack layers
+- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)

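The Int8 SwitchBack layers announced above act as drop-in replacements for `torch.nn.Linear`. A minimal usage sketch, assuming triton is installed and a CUDA GPU with compute capability 8.0 or higher is available (the same requirements the test in this commit gates on); sizes are illustrative:

```python
import torch
import bitsandbytes as bnb

# Global quantization is the default; see the vector_wise_quantization flag below.
layer = bnb.nn.SwitchBackLinear(1024, 4096).cuda().half()

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
y = layer(x)        # int8 matmul inside, dequantized output
print(y.shape)      # torch.Size([8, 4096])
```
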
bitsandbytes/nn/__init__.py
@@ -3,4 +3,4 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb
-from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
+from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear

bitsandbytes/nn/triton_based_modules.py
@@ -157,7 +157,7 @@ class SwitchBackLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
-        vectorize: bool = False,
+        vector_wise_quantization: bool = False,
         mem_efficient : bool = False,
     ):
         super().__init__(in_features, out_features, bias, device, dtype)

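Call sites that passed the old flag need a one-word change; a small migration sketch (layer sizes are illustrative, and constructing the layer still requires triton plus a CUDA device):

```python
from bitsandbytes.nn import SwitchBackLinear

# Before this commit:
#     layer = SwitchBackLinear(1024, 4096, vectorize=True)
# After this commit the keyword is named explicitly:
layer = SwitchBackLinear(1024, 4096, vector_wise_quantization=True)
```
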
@@ -167,11 +167,11 @@ class SwitchBackLinear(nn.Linear):
                 Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')

         # By default, we use the global quantization.
-        self.vectorize = vectorize
-        if self.vectorize:
+        self.vector_wise_quantization = vector_wise_quantization
+        if self.vector_wise_quantization:
             self._fn = _switchback_vectorrize
             if mem_efficient:
-                print('mem efficient is not supported for vectorize mode.')
+                print('mem efficient is not supported for vector-wise quantization.')
                 exit(1)
         else:
             if mem_efficient:

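The renamed flag chooses the quantization granularity: vector-wise keeps one scale per row, global keeps a single scale for the whole weight tensor. A rough pure-PyTorch illustration of that difference; the helper functions here are made up for the sketch and are not the library's Triton kernels:

```python
import torch

def rowwise_absmax_quant(t: torch.Tensor):
    # One absmax-derived scale per row (the vector_wise_quantization=True idea).
    scale = t.abs().amax(dim=1, keepdim=True) / 127.0
    return torch.clamp((t / scale).round(), -127, 127).to(torch.int8), scale

def global_absmax_quant(t: torch.Tensor):
    # A single scale for the whole tensor (the default, global mode).
    scale = t.abs().amax() / 127.0
    return torch.clamp((t / scale).round(), -127, 127).to(torch.int8), scale

w = torch.randn(4, 8)
_, row_scales = rowwise_absmax_quant(w)
_, global_scale = global_absmax_quant(w)
print(row_scales.shape, global_scale.shape)   # torch.Size([4, 1]) torch.Size([])
```

As the hunk shows, mem_efficient is only available together with the global mode; the vector-wise path prints an error and exits.
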
@@ -188,7 +188,7 @@ class SwitchBackLinear(nn.Linear):
         # m.prepare_for_eval()
         # model.apply(cond_prepare)
         print('=> preparing for eval.')
-        if self.vectorize:
+        if self.vector_wise_quantization:
             W_int8, state_W = quantize_rowwise(self.weight)
         else:
             W_int8, state_W = quantize_global(self.weight)

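This hunk sits inside the eval-preparation step, which quantizes the weight once (row-wise or globally, depending on the flag) so it is not re-quantized on every forward pass. The commented-out lines hint at the intended usage; a hedged sketch that just follows those comments, with the placeholder `model` standing in for any network that contains SwitchBack layers:

```python
from bitsandbytes.nn import SwitchBackLinear

def cond_prepare(m):
    # Pre-quantize the weights of every SwitchBackLinear before inference,
    # mirroring the commented-out `model.apply(cond_prepare)` pattern above.
    if isinstance(m, SwitchBackLinear):
        m.prepare_for_eval()

# model.eval()
# model.apply(cond_prepare)   # `model` is a placeholder for your own network
```
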
@@ -210,7 +210,7 @@ class SwitchBackLinear(nn.Linear):
         X = x.view(-1, x.size(-1))
         X_int8, state_X = quantize_rowwise(X)

-        if self.vectorize:
+        if self.vector_wise_quantization:
             return int8_matmul_rowwise_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)

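In the forward pass the input is always quantized row-wise; the flag only determines how the weight was quantized and therefore which fused kernel runs. Below is a self-contained plain-PyTorch sketch of the arithmetic the vector-wise branch performs; it is an approximation for clarity, not the fused Triton kernel `int8_matmul_rowwise_dequantize`, and the quantization helper is made up for the sketch:

```python
import torch

def rowwise_absmax_quant(t: torch.Tensor):
    scale = t.abs().amax(dim=1, keepdim=True) / 127.0
    return torch.clamp((t / scale).round(), -127, 127).to(torch.int8), scale

dim, batch = 83, 13
x = torch.randn(batch, dim)
w = torch.randn(4 * dim, dim)            # [out_features, in_features]
bias = torch.randn(4 * dim)

x_int8, sx = rowwise_absmax_quant(x)     # one scale per input row
w_int8, sw = rowwise_absmax_quant(w)     # one scale per output row

# Integer matmul, then rescale with both sets of scales and add the bias.
out = (x_int8.float() @ w_int8.t().float()) * sx * sw.t() + bias
ref = x @ w.t() + bias
print((out - ref).abs().max())           # small residual from int8 rounding
```
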
@@ -219,9 +219,9 @@ class SwitchBackLinear(nn.Linear):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)

-SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
-SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
-SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
+SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
+SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)

 # This is just the standard linear function.
 class StandardLinearFunction(torch.autograd.Function):

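The renamed aliases are functools.partial bindings of SwitchBackLinear with the new keyword pre-filled, so each variant stays one import away. A short sketch of how they relate (constructing them still requires triton and a CUDA device):

```python
from bitsandbytes.nn import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise

# SwitchBackLinearGlobal(...)     == SwitchBackLinear(..., vector_wise_quantization=False)
# SwitchBackLinearVectorwise(...) == SwitchBackLinear(..., vector_wise_quantization=True)
# SwitchBackLinearGlobalMemEfficient additionally sets mem_efficient=True and is
# defined in bitsandbytes.nn.triton_based_modules.
layer = SwitchBackLinearVectorwise(1024, 4096).cuda().half()
```
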
setup.py
@@ -18,7 +18,7 @@ def read(fname):

 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0.post2",
+    version=f"0.38.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",

tests/test_triton.py
@@ -1,19 +1,19 @@
 import pytest
 import torch

+from bitsandbytes.triton.triton_utils import is_triton_available
 from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
 from bitsandbytes.nn import Linear8bitLt

-@pytest.mark.skipif(not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires a GPU with compute capability 8.0 or higher.")
-@pytest.mark.parametrize("vectorrize", [False, True])
-def test_switchback(vectorrize):
-    for dim in [83, 17, 128]:
-        for batch in [13, 128, 256]:
+@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+def test_switchback(vector_wise_quantization):
+    for dim in [83]:
+        for batch in [13]:

             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            print('vectorrize', vectorrize)
-            switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
+            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)

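The reworked test now gates on triton availability as well as compute capability and drives the renamed keyword through parametrization. For a quick manual check outside pytest, a simplified sketch following the same setup (copy the fp16 reference weights into the SwitchBack layer, then compare forward outputs); the idea of a small residual is illustrative, not the test's actual assertion:

```python
import torch
from bitsandbytes.nn import SwitchBackLinear

dim, batch = 83, 13                                   # same sizes the test keeps
standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=True).cuda().half()
switchback.weight.data.copy_(standard.weight)
switchback.bias.data.copy_(standard.bias)

x = torch.randn(batch, dim, dtype=torch.float16, device="cuda")
# The int8 path should track the fp16 reference up to quantization error.
print((switchback(x) - standard(x)).abs().max())
```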