Merge branch 'fp8sim' of github.com:TimDettmers/bitsandbytes into fp8sim
commit 5d2e23e8d6
@@ -395,38 +395,41 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
     @staticmethod
-    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B

             B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
                 return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)


         # 1. Dequantize
         # 2. MatmulnN
-        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
+        fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)

         cB, state = F.quantize(B.float(), code=fw_code)
         fp8B = F.dequantize(cB, state).to(B.dtype)

         output = torch.matmul(fp8A, fp8B)

+        # output is half

         # 3. Save state
+        ctx.fw_code = fw_code
         ctx.bw_code = bw_code
+        ctx.bsz = bsz
         ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (fp8A, fp8B)
+            # NOTE: we send back A, and re-quant.
+            ctx.tensors = (A, fp8B)
         else:
             ctx.tensors = (None, None)

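The forward pass above simulates an FP8 matmul by pushing A through a blockwise quantize/dequantize round trip whose block size is the new bsz argument (B goes through the non-blockwise F.quantize path). A minimal sketch of that round trip, assuming a CUDA device and this branch of bitsandbytes; the shapes and the error printout are illustrative:

    import torch
    import bitsandbytes.functional as F

    bsz = 1024                                        # block size, now a parameter of MatMulFP8.forward
    fw_code = F.create_fp8_map(True, 4, 3, 8).cuda()  # forward code, as built in LinearFP8.forward
    A = torch.randn(8, 1024, dtype=torch.float16, device='cuda')

    cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
    fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)  # "fake FP8" copy of A
    print((A - fp8A).abs().max())                     # round-trip quantization error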
@@ -435,30 +438,36 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None

-        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
-        fp8A, B = ctx.tensors
+        req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
+        A, B = ctx.tensors

         grad_A, grad_B = None, None

-        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
+        # TODO: Fix blocksize to be output_dim
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
+        fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)

-        # Cast grad_output to fp16
-        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
+        fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
+
+        # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+        # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
+        # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
+        # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])

         # not supported by PyTorch. TODO: create work-around
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
-        if req_gradB:
-            if fp8A.ndim == 3:
-                fp8At = fp8A.transpose(2, 1)
-            elif fp8A.ndim == 2:
-                fp8At = fp8A.t()
-            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
+        if req_gradA:
+            grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)

-        return grad_A, grad_B, None, None, None
+        if req_gradB:
+            At = A.transpose(2, 1).contiguous()
+            cA, state = F.quantize(At.float(), code=ctx.fw_code)
+            fp8At = F.dequantize(cA, state).to(A.dtype)
+            grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
+
+        return grad_A, grad_B, None, None, None, None


 class MatMul8bitMixed(torch.autograd.Function):

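For Y = A @ B the chain rule gives dL/dA = dL/dY @ B^T and dL/dB = A^T @ dL/dY; the backward above computes both with simulated-FP8 operands. grad_output goes through the bw_code round trips, and A (now saved in full precision) is re-quantized with fw_code before the grad_B matmul. A hedged sketch of the new grad_B path for a 3D activation, reusing only calls that appear in the hunk; names and shapes are illustrative:

    import torch
    import bitsandbytes.functional as F

    fw_code = F.create_fp8_map(True, 4, 3, 8).cuda()
    bw_code = F.create_fp8_map(True, 5, 2, 8).cuda()
    A = torch.randn(4, 32, 1024, dtype=torch.float16, device='cuda')            # saved activation
    grad_output = torch.randn(4, 32, 2048, dtype=torch.float16, device='cuda')  # dL/dY

    # grad_output through the non-blockwise bw_code round trip (fp8out_2 above)
    cg, gstate = F.quantize(grad_output.float(), code=bw_code)
    fp8out_2 = F.dequantize(cg, gstate).to(grad_output.dtype)

    # re-quantize A^T with the forward code, then grad_B = A^T @ dL/dY
    At = A.transpose(2, 1).contiguous()
    cA, astate = F.quantize(At.float(), code=fw_code)
    fp8At = F.dequantize(cA, astate).to(A.dtype)
    grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2)                   # shape (4, 1024, 2048)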
@@ -659,8 +668,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
-    return MatMulFP8.apply(A, B, out, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)


 def matmul_mixed(

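End to end, the new keyword only threads the block size through the public wrapper into MatMulFP8. A hedged usage sketch, assuming a CUDA build of this branch; it differentiates only with respect to the activations, and the shapes are illustrative:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    fw_code = F.create_fp8_map(True, 4, 3, 8).cuda()
    bw_code = F.create_fp8_map(True, 5, 2, 8).cuda()

    A = torch.randn(4, 8, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
    B = torch.randn(1024, 4096, dtype=torch.float16, device='cuda')

    out = bnb.matmul_fp8(A, B, fw_code=fw_code, bw_code=bw_code, bsz=1024)
    out.float().sum().backward()   # backward now returns six values, one grad slot per forward input
    print(A.grad.shape)            # torch.Size([4, 8, 1024])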
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2

@@ -346,6 +346,68 @@ class Linear8bitLt(nn.Linear):
         return out


+# Not in use for now...
+class Linear8bitLt2(nn.Linear):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features, output_features, bias
+        )
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
+
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x, self.weight, bias=None, state=self.state) + self.bias
+        #out = torch.matmul(x.half(), W.half().t()) + self.bias
+
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
+
+        return out
+
+
 class Linear8bitLtThresh(Linear8bitLt):
     def __init__(
         self,

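Linear8bitLt2 is flagged as unused for now and mirrors the existing Linear8bitLt: an Int8Params weight, a MatmulLtState, and the same first-pass conversion of the row-major 8-bit weight into the Turing/Ampere layout. If it were wired in, usage would look the same as for Linear8bitLt; a hypothetical sketch, assuming a CUDA device:

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt2(1024, 4096, has_fp16_weights=False, threshold=6.0).cuda()
    x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
    y = layer(x)   # first call moves CB/SCB into the MatmulLtState; later calls reuse CxB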
@@ -363,7 +425,7 @@ class Linear8bitLtThresh(Linear8bitLt):
             bias=bias,
             has_fp16_weights=has_fp16_weights,
             memory_efficient_backward=memory_efficient_backward,
-            threshold=threshold,
+            threshold=6.,
             index=index
         )

@@ -372,13 +434,19 @@ class LinearFP8(nn.Linear):
         super().__init__(input_features, output_features, bias)
         self.bw_code = None
         self.fw_code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+        print('block size is', self.bsz)

     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
             self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

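The constructor change above derives a per-layer block size from input_features: the loop walks the descending candidate list and returns the first entry whose successor is smaller than input_features, i.e. the smallest candidate that is at least input_features, capped at 4096. A standalone restatement of the same loop, with illustrative checks:

    def pick_blocksize(input_features, candidates=(4096, 2048, 1024, 512, 256, 128, 64, 0)):
        # same loop as in the LinearFP8 / LinearInt8 / LinearInt8Cast constructors above
        for i, k in enumerate(candidates):
            if input_features > candidates[i + 1]:
                return k

    assert pick_blocksize(1024) == 1024   # power-of-two feature dims map to themselves
    assert pick_blocksize(1000) == 1024   # otherwise rounds up to the next candidate
    assert pick_blocksize(8192) == 4096   # capped at the largest candidate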
@@ -388,27 +456,39 @@ class LinearInt8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break

     def forward(self, x: torch.Tensor):
         if self.code is None:
             self.code = bnb.functional.create_linear_map(True, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

         return out


+# This is 4 bit version.
 class LinearInt8Cast(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
         self.code = None
+        array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
+        for i, k in enumerate(array):
+            if input_features > array[i + 1]:
+                self.bsz = k
+                break
+

     def forward(self, x: torch.Tensor):
         if self.code is None:
-            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+            self.code = bnb.functional.create_linear_map(True, 4).to(x.device)

-        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
         if self.bias is not None:
             out += self.bias

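Taken together, these layers act as drop-in nn.Linear replacements that simulate low-precision matmuls and now pick their quantization block size from input_features. A hedged sketch of swapping one into a small model, assuming a CUDA device and this branch of bitsandbytes; the model and shapes are illustrative:

    import torch
    import torch.nn as nn
    import bitsandbytes as bnb

    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
    model[0] = bnb.nn.LinearFP8(1024, 4096)   # constructor picks bsz=1024 from input_features
    model = model.half().cuda()

    x = torch.randn(2, 16, 1024, dtype=torch.float16, device='cuda')
    y = model(x)                              # fw_code/bw_code are built lazily on the first call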