From 3fbf60ad83e845677e77c807b884393f25f40c8e Mon Sep 17 00:00:00 2001
From: Mitchell Wortsman
Date: Thu, 23 Feb 2023 08:27:15 +0000
Subject: [PATCH] sim now worse than real

---
 bitsandbytes/autograd/_functions.py | 55 ++++++++++--------
 bitsandbytes/nn/__init__.py         |  2 +-
 bitsandbytes/nn/modules.py          | 90 +++++++++++++++++++++++++++--
 3 files changed, 118 insertions(+), 29 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index aa50b21..6de595e 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -395,38 +395,41 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
+
             B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
                 return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

-        # 1. Dequantize
         # 2. MatmulnN
-
-        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

         cB, state = F.quantize(B.float(), code=fw_code)
         fp8B = F.dequantize(cB, state).to(B.dtype)

         output = torch.matmul(fp8A, fp8B)
+        # output is half

         # 3. Save state
+        ctx.fw_code = fw_code
         ctx.bw_code = bw_code
+        ctx.bsz = bsz
         ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (fp8A, fp8B)
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
         else:
             ctx.tensors = (None, None)

@@ -435,30 +438,36 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None

-        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
-        fp8A, B = ctx.tensors
+        req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors

         grad_A, grad_B = None, None

-        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)

-        # Cast grad_output to fp16
-        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
+
+        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
+        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
+        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

         # not supported by PyTorch. TODO: create work-around
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+        if req_gradA:
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
+
         if req_gradB:
-            if fp8A.ndim == 3:
-                fp8At = fp8A.transpose(2, 1)
-            elif fp8A.ndim == 2:
-                fp8At = fp8A.t()
-            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
+            At = A.transpose(2, 1).contiguous()
+            cA, state = F.quantize(At.float(), code=ctx.fw_code)
+            fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

-        return grad_A, grad_B, None, None, None
+        return grad_A, grad_B, None, None, None, None


 def matmul(
@@ -475,8 +484,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
-    return MatMulFP8.apply(A, B, out, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)


 def matmul(
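The forward pass above simulates FP8 by fake-quantizing both operands (blockwise for A via quantize_blockwise/dequantize_blockwise, per-tensor for B via quantize/dequantize) and then running the matmul in the original precision; the backward now keeps the full-precision A and re-quantizes it, along with grad_output, on the fly. For reference, here is a minimal pure-PyTorch sketch of that fake-quantize-then-matmul pattern. It does not use the bitsandbytes kernels; fake_quantize_blockwise and the 16-entry code table are illustrative stand-ins for F.quantize_blockwise/F.dequantize_blockwise and create_fp8_map.

# Pure-PyTorch sketch of the simulated low-precision matmul used above (assumptions noted in the lead-in).
import torch

def fake_quantize_blockwise(x, code, blocksize=1024):
    # quantize each block of `blocksize` values to the nearest code entry, then dequantize
    flat = x.reshape(-1)
    pad = (-flat.numel()) % blocksize
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    blocks = flat.reshape(-1, blocksize)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)   # per-block scale
    normed = blocks / absmax                                           # values now in [-1, 1]
    idx = (normed.unsqueeze(-1) - code).abs().argmin(dim=-1)           # nearest code entry
    deq = code[idx] * absmax                                           # back to the input scale
    return deq.reshape(-1)[: x.numel()].reshape_as(x)

code = torch.linspace(-1, 1, steps=16)      # toy stand-in for an FP8 lookup table

A = torch.randn(4, 16, 64)
B = torch.randn(64, 32)
fp8A = fake_quantize_blockwise(A, code, blocksize=64)          # blockwise, like A above
fp8B = fake_quantize_blockwise(B, code, blocksize=B.numel())   # one block, roughly per-tensor like B above
out = torch.matmul(fp8A, fp8B)              # the matmul itself runs in the original dtype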
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index ae9eb8c..9c70642 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2
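The __init__.py change only adds Linear8bitLt2 to the public exports alongside the other experimental layers, so everything defined in modules.py below is reachable as bnb.nn.*. A hypothetical usage sketch, assuming a CUDA build of bitsandbytes from this branch; the shapes are arbitrary:

import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearFP8(1024, 4096).cuda()   # the constructor prints the chosen block size (1024 here)
x = torch.randn(8, 128, 1024, device="cuda", requires_grad=True)
y = layer(x)           # forward: blockwise fake-FP8 on x, per-tensor fake-FP8 on the weight
y.sum().backward()     # backward: gradients flow through the quantize/dequantize proxies in MatMulFP8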
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 23f391a..5c0d0d4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -346,6 +346,68 @@ class Linear8bitLt(nn.Linear):
         return out

+
+# Not in use for now...
+class Linear8bitLt2(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x, self.weight, bias=None, state=self.state) + self.bias
+        #out = torch.matmul(x.half(), W.half().t()) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
+
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
         self,
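Linear8bitLt2 is flagged "Not in use for now" and mirrors the existing Linear8bitLt: the Int8Params weight is quantized when the module is moved to the GPU, the first forward hands CB/SCB over to the MatmulLtState, and with has_fp16_weights=False the row-major CB is dropped once the Turing/Ampere layout CxB exists. A hypothetical inference-only sketch of that lifecycle, not exercised anywhere in this patch:

import torch
import bitsandbytes as bnb

lin = bnb.nn.Linear8bitLt2(1024, 4096, has_fp16_weights=False, threshold=6.0)
lin = lin.cuda()                 # Int8Params quantizes the weight to int8 during .cuda()
x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y = lin(x)                   # first pass moves CB/SCB into self.state, then keeps only CxB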
@@ -363,7 +425,7 @@ class Linear8bitLtThresh(Linear8bitLt):
             bias=bias,
             has_fp16_weights=has_fp16_weights,
             memory_efficient_backward=memory_efficient_backward,
-            threshold=threshold,
+            threshold=6.,
             index=index
         )

@@ -372,13 +434,19 @@ class LinearFP8(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.bw_code = None
         self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        print('block size is', self.bsz)

     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

@@ -388,27 +456,39 @@ class LinearInt8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
             self.code = bnb.functional.create_linear_map(True, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

         return out

+
+# This is 4 bit version.
 class LinearInt8Cast(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+
     def forward(self, x: torch.Tensor):
         if self.code is None:
-            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+            self.code = bnb.functional.create_linear_map(True, 4).to(x.device)

-        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias
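The constructors of LinearFP8, LinearInt8 and LinearInt8Cast all repeat the same block-size scan. It effectively rounds input_features up to the next size in the list and caps it at 4096 (e.g. 1024 -> 1024, 1025 -> 2048, 50 -> 64, 8192 -> 4096). Below is a small sketch of that heuristic as a shared helper; the name pick_blocksize is illustrative and not part of the patch.

def pick_blocksize(input_features: int) -> int:
    # same scan as the three constructors above: the smallest size in the list
    # that is >= input_features, capped at 4096
    sizes = [4096, 2048, 1024, 512, 256, 128, 64, 0]
    for i, k in enumerate(sizes[:-1]):
        if input_features > sizes[i + 1]:
            return k
    raise ValueError("input_features must be positive")

assert pick_blocksize(1024) == 1024
assert pick_blocksize(1025) == 2048
assert pick_blocksize(50) == 64
assert pick_blocksize(8192) == 4096

Note that the loop in the patch raises an IndexError for non-positive input_features; the helper just makes that failure mode explicit.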