sim now worse than real

Mitchell Wortsman 2023-02-23 08:27:15 +00:00
parent 7b764d3569
commit 3fbf60ad83
3 changed files with 118 additions and 29 deletions


@@ -395,38 +395,41 @@ class MatMulFP8(torch.autograd.Function):
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
# match default PyTorch behavior when inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
B_shape = B.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. Matmul
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
cB, state = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, state).to(B.dtype)
output = torch.matmul(fp8A, fp8B)
# output is half
# 3. Save state
ctx.fw_code = fw_code
ctx.bw_code = bw_code
ctx.bsz = bsz
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (fp8A, fp8B)
# NOTE: we save the full-precision A and re-quantize it in backward.
ctx.tensors = (A, fp8B)
else:
ctx.tensors = (None, None)
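For context, the forward pass simulates FP8 by round-tripping the inputs through quantize/dequantize and then running a normal matmul in the original dtype. A minimal standalone sketch of that roundtrip, using the same bitsandbytes.functional calls as the diff (shapes, device and the E4M3-style code are illustrative assumptions):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(8, 1024, device="cuda", dtype=torch.float16)    # hypothetical activation
B = torch.randn(1024, 4096, device="cuda", dtype=torch.float16) # hypothetical weight

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bsz = 1024  # block size, now a parameter instead of a hard-coded 1024

# blockwise fake-FP8 for the activation
cA, state_A = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
fp8A = F.dequantize_blockwise(cA, state_A, blocksize=bsz).to(A.dtype)

# global fake-FP8 for the weight
cB, state_B = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, state_B).to(B.dtype)

out = torch.matmul(fp8A, fp8B)  # the matmul itself runs in the original dtype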
@@ -435,30 +438,36 @@ class MatMulFP8(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
fp8A, B = ctx.tensors
req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
A, B = ctx.tensors
grad_A, grad_B = None, None
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=1024).to(grad_output.dtype)
# TODO: Fix blocksize to be output_dim
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
# not supported by PyTorch. TODO: create work-around
if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
if req_gradA:
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
if req_gradB:
if fp8A.ndim == 3:
fp8At = fp8A.transpose(2, 1)
elif fp8A.ndim == 2:
fp8At = fp8A.t()
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
At = A.transpose(2, 1).contiguous()
cA, state = F.quantize(At.float(), code=ctx.fw_code)
fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
return grad_A, grad_B, None, None, None
return grad_A, grad_B, None, None, None, None
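In plain terms, the new backward computes grad_A from a blockwise fake-FP8 copy of grad_output, and grad_B from a freshly re-quantized A (now saved in full precision) times a globally fake-quantized grad_output. A simplified sketch of that dataflow for the 2D case, with no autograd plumbing and illustrative names:

def fp8_backward_2d(grad_output, A, fp8B, fw_code, bw_code, bsz):
    # activation gradient: blockwise fake-FP8 grad_output times B^T
    cg, st = F.quantize_blockwise(grad_output, code=bw_code, blocksize=bsz)
    g_blk = F.dequantize_blockwise(cg, st, blocksize=bsz).to(grad_output.dtype)
    grad_A = torch.matmul(g_blk, fp8B.t().to(g_blk.dtype)).to(A.dtype)

    # weight gradient: re-quantized A^T times globally fake-FP8 grad_output
    cg2, st2 = F.quantize(grad_output.float(), code=bw_code)
    g_glob = F.dequantize(cg2, st2).to(grad_output.dtype)
    cA, stA = F.quantize(A.t().contiguous().float(), code=fw_code)
    At_q = F.dequantize(cA, stA).to(A.dtype)
    grad_B = torch.matmul(At_q.to(g_glob.dtype), g_glob).to(fp8B.dtype)
    return grad_A, grad_B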
def matmul(
@@ -475,8 +484,8 @@ def matmul(
return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
return MatMulFP8.apply(A, B, out, fw_code, bw_code)
def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1):
return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)
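A hypothetical call site, with the codes created the same way LinearFP8 does below (x and weight are placeholder tensors):

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)  # E4M3-style
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)  # E5M2-style
out = bnb.matmul_fp8(x, weight.t(), fw_code=fw_code, bw_code=bw_code, bsz=1024)

Note that the wrapper's new default of bsz=-1 appears to be a placeholder rather than a usable block size; the nn wrappers below always compute and pass an explicit bsz.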
def matmul(


@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2


@@ -346,6 +346,68 @@ class Linear8bitLt(nn.Linear):
return out
# Not in use for now...
class Linear8bitLt2(nn.Linear):
def __init__(
self,
input_features,
output_features,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()
self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
)
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
def forward(self, x):
self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
# weights are cast automatically as Int8Params, but the bias has to be cast manually
# if self.bias is not None and self.bias.dtype != torch.float16:
# self.bias.data = self.bias.data.half()
#out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
out = bnb.matmul(x, self.weight, bias=None, state=self.state) + self.bias
#out = torch.matmul(x.half(), W.half().t()) + self.bias
if not self.state.has_fp16_weights:
if not self.state.memory_efficient_backward and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
elif self.state.memory_efficient_backward and self.state.CxB is not None:
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
# Thus, we delete CxB from the state.
del self.state.CxB
return out
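A minimal usage sketch for the new (currently unused) Linear8bitLt2, with hypothetical layer sizes; it mirrors Linear8bitLt but drops the explicit .half() casts on the input and bias:

layer = Linear8bitLt2(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
y = layer(x)  # bnb.matmul with MatmulLtState; the bias is added outside the matmul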
class Linear8bitLtThresh(Linear8bitLt):
def __init__(
self,
@@ -363,7 +425,7 @@ class Linear8bitLtThresh(Linear8bitLt):
bias=bias,
has_fp16_weights=has_fp16_weights,
memory_efficient_backward=memory_efficient_backward,
threshold=threshold,
threshold=6.,
index=index
)
@@ -372,13 +434,19 @@ class LinearFP8(nn.Linear):
super().__init__(input_features, output_features, bias)
self.bw_code = None
self.fw_code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
print('block size is', self.bsz)
def forward(self, x: torch.Tensor):
if self.fw_code is None:
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
if self.bias is not None:
out += self.bias
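The block-size selection that now appears in LinearFP8 (and in LinearInt8 / LinearInt8Cast below) effectively picks the smallest size from {64, 128, ..., 4096} that is at least input_features, capped at 4096. An equivalent standalone helper, just to make that intent explicit (the function name is hypothetical):

def pick_blocksize(input_features: int) -> int:
    # smallest supported block size that covers input_features, capped at 4096
    for bsz in (64, 128, 256, 512, 1024, 2048, 4096):
        if input_features <= bsz:
            return bsz
    return 4096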
@@ -388,27 +456,39 @@ class LinearInt8(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
def forward(self, x: torch.Tensor):
if self.code is None:
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
if self.bias is not None:
out += self.bias
return out
# This is the 4-bit version.
class LinearInt8Cast(nn.Linear):
def __init__(self, input_features, output_features, bias=True):
super().__init__(input_features, output_features, bias)
self.code = None
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
for i, k in enumerate(array):
if input_features > array[i + 1]:
self.bsz = k
break
def forward(self, x: torch.Tensor):
if self.code is None:
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
self.code = bnb.functional.create_linear_map(True, 4).to(x.device)
out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
if self.bias is not None:
out += self.bias
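The Int8 and Int4 simulation layers reuse the same matmul_fp8 path, only with uniform (linear) quantization codes instead of FP8 maps. A short sketch of how those codes are built and used, with x and w as hypothetical tensors:

code8 = bnb.functional.create_linear_map(True, 8).to(x.device)  # signed 8-bit linear code (LinearInt8)
code4 = bnb.functional.create_linear_map(True, 4).to(x.device)  # signed 4-bit linear code (LinearInt8Cast)
out = bnb.matmul_fp8(x, w.t(), fw_code=code4, bw_code=code4, bsz=1024)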