new experiments
parent 5d2e23e8d6
commit 75377d125e
bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed
bitsandbytes/nn/modules.py
@@ -407,6 +407,65 @@ class Linear8bitLt2(nn.Linear):
         return out


+class Linear8bitLtMixed(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
+
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
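For context, here is a minimal usage sketch of the new module (not part of the commit). It assumes this experimental branch is installed, since `bnb.matmul_mixed` does not exist in mainline bitsandbytes, that a CUDA device is available, and the layer dimensions are arbitrary:

```python
# Hypothetical usage sketch for the new Linear8bitLtMixed layer on the int8
# inference path. Assumes this experimental branch of bitsandbytes (mainline
# has no bnb.matmul_mixed) and a CUDA GPU.
import torch
from bitsandbytes.nn import Linear8bitLtMixed

# has_fp16_weights=False together with threshold > 0.0 enables the outlier
# pool (the use_pool branch in __init__ above).
layer = Linear8bitLtMixed(1024, 4096, bias=True,
                          has_fp16_weights=False, threshold=6.0)
layer = layer.cuda()  # Int8Params quantizes the weight when moved to the GPU

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
out = layer(x)  # first pass reorders row-major int8 weights to turing/ampere format
print(out.shape)  # torch.Size([8, 4096])
```

Per the comments in `forward`, after that first pass `state.CB` is deleted and `self.weight.data` is rebound to the turing/ampere tensor `state.CxB`, so later passes skip the conversion; with `memory_efficient_backward=True` the layout is instead rebuilt (and `CxB` dropped) on every pass, trading compute for memory.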