Merge branch 'fp8sim' of github.com:TimDettmers/bitsandbytes into fp8sim

Author: Tim Dettmers
Date:   2023-02-23 10:56:49 -08:00
Commit: 5d2e23e8d6

3 changed files with 119 additions and 30 deletions

View File

@@ -395,38 +395,41 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
     @staticmethod
-    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B

             B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
                 return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

         # 1. Dequantize
         # 2. MatmulnN
-        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

         cB, state = F.quantize(B.float(), code=fw_code)
         fp8B = F.dequantize(cB, state).to(B.dtype)

         output = torch.matmul(fp8A, fp8B)
+        # output is half

         # 3. Save state
+        ctx.fw_code = fw_code
         ctx.bw_code = bw_code
+        ctx.bsz = bsz
         ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (fp8A, fp8B)
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
         else:
             ctx.tensors = (None, None)
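
Note: the forward pass above simulates FP8 purely by a quantize/dequantize round trip followed by a regular matmul in the original dtype. A minimal self-contained sketch of that idea in plain PyTorch (the absmax/codebook scheme and the helper name are illustrative stand-ins, not the bitsandbytes F.quantize_blockwise kernels used in the diff):

import torch

def fake_blockwise_quant(x: torch.Tensor, code: torch.Tensor, blocksize: int = 1024):
    # Absmax-normalize each block, snap to the nearest codebook value, then rescale.
    flat = x.flatten().float()
    pad = (-flat.numel()) % blocksize
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, blocksize)
    absmax = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-12)
    idx = ((blocks / absmax).unsqueeze(-1) - code).abs().argmin(dim=-1)
    deq = code[idx] * absmax
    return deq.flatten()[: x.numel()].view_as(x).to(x.dtype)

A = torch.randn(4, 8)
B = torch.randn(8, 16)
code = torch.linspace(-1, 1, 256)  # stand-in for an FP8 code map
out = torch.matmul(fake_blockwise_quant(A, code, blocksize=8),
                   fake_blockwise_quant(B, code, blocksize=8))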
@@ -435,30 +438,36 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None

-        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
-        fp8A, B = ctx.tensors
+        req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors

         grad_A, grad_B = None, None

-        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)

-        # Cast grad_output to fp16
-        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)

+        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
+        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
+        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

         # not supported by PyTorch. TODO: create work-around
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
-        if req_gradB:
-            if fp8A.ndim == 3:
-                fp8At = fp8A.transpose(2, 1)
-            elif fp8A.ndim == 2:
-                fp8At = fp8A.t()
-            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
+        if req_gradA:
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

-        return grad_A, grad_B, None, None, None
+        if req_gradB:
+            At = A.transpose(2, 1).contiguous()
+            cA, state = F.quantize(At.float(), code=ctx.fw_code)
+            fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)

+        return grad_A, grad_B, None, None, None, None


 class MatMul8bitMixed(torch.autograd.Function):
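
Note: the backward above is the standard matmul gradient with quantize/dequantize round trips spliced in: grad_A = grad_out @ B^T and grad_B = A^T @ grad_out, where the quantized grad_out and the re-quantized A stand in for their full-precision counterparts. A quantization-free sanity check of those identities (illustrative only, plain PyTorch autograd):

import torch

A = torch.randn(3, 5, 4, requires_grad=True)   # batched activation, like A in the diff
B = torch.randn(4, 6, requires_grad=True)      # shared weight
out = torch.matmul(A, B)
grad_out = torch.randn_like(out)
out.backward(grad_out)

manual_grad_A = torch.matmul(grad_out, B.t())
# for a weight shared across the batch, the batch dimension is summed
manual_grad_B = torch.matmul(A.transpose(2, 1), grad_out).sum(dim=0)
assert torch.allclose(A.grad, manual_grad_A, atol=1e-5)
assert torch.allclose(B.grad, manual_grad_B, atol=1e-5)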
@@ -659,8 +668,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
-    return MatMulFP8.apply(A, B, out, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)


 def matmul_mixed(
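
Note: a hedged usage sketch of the updated matmul_fp8 entry point, mirroring the LinearFP8 call pattern later in this commit. It assumes a CUDA device and this fp8sim branch; the shapes and blocksize are illustrative, and the create_fp8_map arguments are the ones used by LinearFP8 below:

import torch
import bitsandbytes as bnb

x = torch.randn(8, 1024, device='cuda', dtype=torch.half)
w = torch.randn(2048, 1024, device='cuda', dtype=torch.half)

# E4M3-style code map for the forward, E5M2-style for the backward, as in LinearFP8.
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)

# bsz is the blocksize used for the blockwise quantization of the activation.
out = bnb.matmul_fp8(x, w.t(), fw_code=fw_code, bw_code=bw_code, bsz=1024)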

View File

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2

View File

@@ -346,6 +346,68 @@ class Linear8bitLt(nn.Linear):

         return out


+# Not in use for now...
+class Linear8bitLt2(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        # out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x, self.weight, bias=None, state=self.state) + self.bias
+        # out = torch.matmul(x.half(), W.half().t()) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
+
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
         self,
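
Note: the new Linear8bitLt2 above is flagged as not in use for now. A hedged usage sketch, assuming the same calling convention as Linear8bitLt (CUDA device, fp16 activations, int8 weights quantized on the device move); this is an illustration, not a tested path in the commit:

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt2(1024, 4096, has_fp16_weights=False, threshold=6.0)
layer = layer.cuda()  # Int8Params quantizes the weight when moved to the GPU
x = torch.randn(8, 1024, device='cuda', dtype=torch.half)
y = layer(x)          # int8 matmul via bnb.matmul; the bias is added afterwards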
@@ -363,7 +425,7 @@ class Linear8bitLtThresh(Linear8bitLt):
             bias=bias,
             has_fp16_weights=has_fp16_weights,
             memory_efficient_backward=memory_efficient_backward,
-            threshold=threshold,
+            threshold=6.,
             index=index
         )
@@ -372,13 +434,19 @@ class LinearFP8(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.bw_code = None
         self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        print('block size is', self.bsz)

     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias
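
Note: the block-size loop added to LinearFP8 above (and repeated in LinearInt8 and LinearInt8Cast below) effectively picks the smallest entry of the table that is >= input_features, clamped to the [64, 4096] range. A standalone restatement (the helper name is illustrative and not part of the commit; the original loop assumes input_features > 0):

def select_blocksize(input_features: int) -> int:
    # Same table and comparison as the loop in the diff, written as a helper.
    array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
    for i, k in enumerate(array[:-1]):
        if input_features > array[i + 1]:
            return k
    return 64  # fallback for very small inputs

assert select_blocksize(768) == 1024
assert select_blocksize(1024) == 1024
assert select_blocksize(4096) == 4096
assert select_blocksize(50) == 64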
@@ -388,27 +456,39 @@ class LinearInt8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
             self.code = bnb.functional.create_linear_map(True, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

         return out


+# This is 4 bit version.
 class LinearInt8Cast(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
-            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+            self.code = bnb.functional.create_linear_map(True, 4).to(x.device)

-        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias