From 7b764d35698eb77f20768e3f62b0e53f3044fb5f Mon Sep 17 00:00:00 2001
From: Mitchell Wortsman
Date: Tue, 21 Feb 2023 03:53:44 +0000
Subject: [PATCH] adding half() cast

---
 bitsandbytes/autograd/_functions.py | 14 ++++---
 bitsandbytes/nn/__init__.py         |  2 +-
 bitsandbytes/nn/modules.py          | 59 +++++++++++++++++++++++++++--
 3 files changed, 66 insertions(+), 9 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index b8b2dbc..aa50b21 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -415,8 +415,8 @@ class MatMulFP8(torch.autograd.Function):
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
         fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)

-        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)

         output = torch.matmul(fp8A, fp8B)

@@ -450,9 +450,13 @@ class MatMulFP8(torch.autograd.Function):
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

         # not supported by PyTorch. TODO: create work-around
-        #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t())
-        if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
+        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+        if req_gradB:
+            if fp8A.ndim == 3:
+                fp8At = fp8A.transpose(2, 1)
+            elif fp8A.ndim == 2:
+                fp8At = fp8A.t()
+            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)

         return grad_A, grad_B, None, None, None

diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 7c2b552..ae9eb8c 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index c8a3ecc..23f391a 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -326,10 +326,11 @@ class Linear8bitLt(nn.Linear):
             self.init_8bit_state()

         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()

-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias

         if not self.state.has_fp16_weights:
             if not self.state.memory_efficient_backward and self.state.CB is not None:
@@ -344,6 +345,28 @@

         return out

+
+class Linear8bitLtThresh(Linear8bitLt):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features,
+            output_features,
+            bias=bias,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
+            index=index
+        )
+
 class LinearFP8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
@@ -361,3 +384,33 @@

         return out

+class LinearInt8(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
+
+class LinearInt8Cast(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
+
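
Usage sketch (illustrative, not part of the patch): the snippet below shows one way the LinearInt8Cast layer added above could be exercised end to end through bnb.matmul_fp8 and the modified MatMulFP8 backward. It assumes this patched bitsandbytes branch is installed and a CUDA device is available; the layer sizes, batch size, and toy loss are arbitrary.

# Illustrative usage of the LinearInt8Cast layer added by this patch.
# Assumes the patched bitsandbytes branch is installed and a CUDA GPU is present.
import torch
import bitsandbytes as bnb

# Build the layer and cast it to half, matching the half() usage this commit targets.
layer = bnb.nn.LinearInt8Cast(1024, 4096, bias=True).cuda().half()

# The input can stay in fp32; forward() casts it to half before calling bnb.matmul_fp8.
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

out = layer(x)                    # forward through the quantize/dequantize matmul
loss = out.float().pow(2).mean()  # toy loss, reduced in fp32 for stability
loss.backward()                   # backward runs MatMulFP8.backward

print(out.dtype, x.grad.dtype, x.grad.shape)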