From 75377d125e59f6ce183ff89b6231082aa70b492e Mon Sep 17 00:00:00 2001
From: Mitchell Wortsman
Date: Fri, 24 Feb 2023 00:10:15 +0000
Subject: [PATCH] new experiments

---
 bitsandbytes/nn/__init__.py |  2 +-
 bitsandbytes/nn/modules.py  | 59 +++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 9c70642..5ec46b3 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 5c0d0d4..94c9aa2 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -407,6 +407,65 @@ class Linear8bitLt2(nn.Linear):
 
         return out
 
+class Linear8bitLtMixed(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
 
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
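
Below is a minimal usage sketch, not part of the patch itself, showing how the Linear8bitLtMixed layer added above might be constructed and called. It assumes a CUDA GPU, this experimental branch of bitsandbytes installed, and that bnb.matmul_mixed (called by forward() above but not defined in this patch) is available at runtime; the layer sizes and threshold value are illustrative only.

    # Usage sketch for the new layer (hypothetical sizes/threshold).
    import torch
    import bitsandbytes as bnb

    # threshold > 0 together with has_fp16_weights=False enables the outlier
    # pooling path (state.use_pool) set up in __init__ above.
    layer = bnb.nn.Linear8bitLtMixed(
        1024,
        4096,
        bias=True,                 # forward() adds self.bias unconditionally
        has_fp16_weights=False,
        threshold=6.0,
    )
    # The patch leaves the automatic bias cast commented out, so cast it here.
    layer.bias.data = layer.bias.data.half()
    layer = layer.cuda()           # Int8Params is expected to quantize the weight on .cuda()

    x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    out = layer(x)                 # first call moves the int8 weight into the GPU-specific CxB layout
    print(out.shape)               # expected: torch.Size([8, 4096])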