Added fp8 simulation layer.

Tim Dettmers 2023-02-13 16:53:07 -08:00
parent c9f505064e
commit 6bdb6c351e
4 changed files with 209 additions and 0 deletions
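The commit does not add native fp8 kernels; "simulation" here means tensors are quantized to an 8-bit floating-point code and immediately dequantized again, so a regular matmul runs on values rounded to the fp8 grid. A minimal sketch of that round trip, using only the bitsandbytes.functional calls this commit relies on (the E4M3 bit split mirrors the forward code that LinearFP8 creates below):

import torch
import bitsandbytes.functional as F

x = torch.randn(16, 32, device="cuda")

# 8-bit float code: signed, 4 exponent bits, 3 mantissa bits (E4M3)
fp8_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)

# round trip through the fp8 grid: blockwise quantize, then dequantize
cx, state = F.quantize_blockwise(x, code=fp8_code, blocksize=1024)
x_fp8 = F.dequantize_blockwise(cx, state)

print((x - x_fp8).abs().max())  # small, non-zero rounding error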

View File

@ -10,6 +10,7 @@ from .autograd._functions import (
    matmul,
    matmul_cublas,
    mm_cublas,
    matmul_fp8,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

View File

@ -390,6 +390,98 @@ class MatMul8bitLt(torch.autograd.Function):
        return grad_A, grad_B, None, grad_bias, None

class MatMulFP8(torch.autograd.Function):
    # simulated fp8 matmul: A and B are quantized to an fp8 code and dequantized
    # again before a regular matmul; the backward pass does the same to the gradient

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, fw_code=None, bw_code=None):
        # default of pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            B_shape = B.shape
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize A and B to fp8 and dequantize again (fp8 simulation)
        # 2. Matmul
        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
        fp8A = F.dequantize_blockwise(cA, state).to(A.dtype)

        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
        fp8B = F.dequantize_blockwise(cB, state).to(B.dtype)

        output = torch.matmul(fp8A, fp8B)
        if bias is not None:
            output += bias

        # 3. Save state
        ctx.bw_code = bw_code
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (fp8A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output
    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None, None

        req_gradA, req_gradB, _, req_gradBias, _, _ = ctx.needs_input_grad
        fp8A, fp8B = ctx.tensors

        grad_A, grad_B, grad_bias = None, None, None

        # quantize/dequantize the incoming gradient with the backward fp8 code
        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
        fp8out = F.dequantize_blockwise(cgrad_out, state).to(grad_output.dtype)

        if req_gradBias:
            # bias gradient: sum over the batch dimension
            grad_bias = fp8out.sum(0, dtype=ctx.dtype_bias)

        # note: batched (3D) inputs are not handled here. TODO: create work-around
        if req_gradA:
            grad_A = torch.matmul(fp8out, fp8B.t())
        if req_gradB:
            grad_B = torch.matmul(fp8A.t(), fp8out)

        return grad_A, grad_B, None, grad_bias, None, None

def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, bias, state)

def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bias=None):
    assert fw_code is not None and bw_code is not None
    return MatMulFP8.apply(A, B, out, bias, fw_code, bw_code)

def matmul(
A: tensor,
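For reference, the new autograd path can be driven directly through matmul_fp8; this is a usage sketch rather than part of the commit, with the fw/bw codes chosen to match the defaults that LinearFP8 sets up below:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(8, 64, device="cuda", requires_grad=True)
B = torch.randn(64, 32, device="cuda", requires_grad=True)

fw_code = F.create_fp8_map(True, 4, 3, 8).to(A.device)  # E4M3 for the forward matmul
bw_code = F.create_fp8_map(True, 5, 2, 8).to(A.device)  # E5M2 for the incoming gradient

out = bnb.matmul_fp8(A, B, fw_code, bw_code)
out.sum().backward()  # grad_A and grad_B also pass through the fp8 round trip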

View File

@ -343,3 +343,19 @@ class Linear8bitLt(nn.Linear):
            del self.state.CxB

        return out

class LinearFP8(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.bw_code = None
        self.fw_code = None

    def forward(self, x: torch.Tensor):
        if self.fw_code is None:
            # lazily create the fp8 codes on first use: E5M2 for gradients, E4M3 for the forward matmul
            self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
            self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)

        out = bnb.matmul_fp8(x, self.weight.t(), bias=self.bias, fw_code=self.fw_code, bw_code=self.bw_code)

        return out
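As a usage sketch (not part of the commit): LinearFP8 is meant as a drop-in replacement for torch.nn.Linear, creating its fp8 codes lazily on the first forward pass. The import path below assumes the class sits next to Linear8bitLt in bitsandbytes.nn.modules:

import torch
from bitsandbytes.nn.modules import LinearFP8  # assumed module path

model = torch.nn.Sequential(
    LinearFP8(64, 256),
    torch.nn.ReLU(),
    LinearFP8(256, 10),
).cuda()

x = torch.randn(8, 64, device="cuda")
loss = model(x).sum()
loss.backward()  # gradients flow through the simulated fp8 matmuls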

View File

@ -429,3 +429,103 @@ def test_matmullt(
            if req_grad[2]:
                torch.testing.assert_allclose(gradBias1, gradBias2)

n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
funcs = [(torch.matmul, bnb.matmul_fp8)]
str_funcs = ["matmul"]
req_grad = list(product([True, False], repeat=3))
req_grad_str = []
for c in req_grad:
    strval = ''
    for v in c:
        if v == True: strval += 'T'
        else: strval += 'F'
    req_grad_str.append(strval)
transpose = [(False, True), (False, False)]
str_transpose = ["NT", "NN"]
dtype = [torch.float16, torch.float32]
has_fp16_weights = [True, False]
has_bias = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    if has_bias == False:
        req_grad = list(req_grad)
        req_grad[2] = False
    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
            bias = None
            bias2 = None
            if has_bias:
                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
                bias2 = bias.clone()
            torch.nn.init.xavier_uniform_(B)
            # fp8 codes matching the LinearFP8 defaults: E4M3 forward, E5M2 backward
            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t(), fw_code, bw_code, bias=bias2)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B, fw_code, bw_code, bias=bias2)
            if has_bias:
                out_torch += bias

            assert out_bnb.dtype == A.dtype, f"bnb matmul_fp8 received {A.dtype} but returned {out_bnb.dtype}"

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).float().mean().item()
            if n > 0:
                assert err < 0.115

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias1 = bias.grad
                    bias.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None
                if has_bias:
                    gradBias2 = bias.grad
                    bias.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[2]:
                    torch.testing.assert_allclose(gradBias1, gradBias2)
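To run only the new parametrized test (assuming the file is tests/test_autograd.py, where the neighboring test_matmullt lives), something like:

pytest tests/test_autograd.py -k test_matmul_fp8 -x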