Added forward/backward tests; removed bias.
This commit is contained in:
parent 6bdb6c351e
commit ca3236587a
@@ -395,15 +395,14 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
-            ctx.bias = bias
-            B_shape = state[1]
+            B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
@@ -414,17 +413,17 @@ class MatMulFP8(torch.autograd.Function):
         # 2. MatmulnN

         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state)
+        fp8A = F.dequantize_blockwise(cA, state).to(A.dtype)

         cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state)
+        fp8B = F.dequantize_blockwise(cB, state).to(B.dtype)

-        output = torch.nn.functional.linear(fp8A, fp8B)
+        output = torch.matmul(fp8A, fp8B)


         # 3. Save state
         ctx.bw_code = bw_code
-        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (fp8A, fp8B)
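For reference, the forward path this hunk settles on simulates FP8 by running each operand through a blockwise quantize/dequantize round trip and then doing a plain matmul. A minimal standalone sketch, assuming the same E4M3-style code map and 1024 blocksize used in the diff (the helper name is illustrative, not part of the library):

import torch
import bitsandbytes.functional as F

def fp8_matmul_forward_sketch(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # E4M3-style code map (signed, 4 exponent bits, 3 mantissa bits, 8 bits total),
    # matching the fw_code built elsewhere in this diff.
    fw_code = F.create_fp8_map(True, 4, 3, 8).to(A.device)

    # Fake-quantize A: blockwise quantize against the FP8 code, then dequantize back.
    cA, state_A = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
    fp8A = F.dequantize_blockwise(cA, state_A).to(A.dtype)

    # Same round trip for B.
    cB, state_B = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
    fp8B = F.dequantize_blockwise(cB, state_B).to(B.dtype)

    # Plain matmul on the FP8-rounded operands.
    return torch.matmul(fp8A, fp8B)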
@@ -436,21 +435,15 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None

-        req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
+        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
         fp8A, B = ctx.tensors
         state = ctx.state

-        grad_A, grad_B, grad_bias = None, None, None
+        grad_A, grad_B = None, None

-        cgrad_out, state = F.quantize_blockwise(grad_ouput, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state)
-
-        if req_gradBias:
-            # compute grad_bias first before changing grad_output dtype
-            grad_bias = fp8out.sum(0, dtype=ctx.dtype_bias)
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
+        fp8out = F.dequantize_blockwise(cgrad_out, state).to(grad_output.dtype)

         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
@@ -461,7 +454,7 @@ class MatMulFP8(torch.autograd.Function):
         if req_gradA: grad_A = torch.matmul(fp8out, B.t())
         if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)

-        return grad_A, grad_B, None, grad_bias, None, None
+        return grad_A, grad_B, None, None, None


 def matmul(
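The backward pass mirrors the forward: grad_output goes through the same blockwise round trip, this time with the E5M2-style bw_code, and the standard matmul gradients are formed from the saved fake-quantized operands. A rough sketch for the 2D case, under the same assumptions as above:

import torch
import bitsandbytes.functional as F

def fp8_matmul_backward_sketch(grad_output, fp8A, fp8B, bw_code):
    # Fake-quantize the incoming gradient with the backward (E5M2-style) code map.
    cgrad, state = F.quantize_blockwise(grad_output, code=bw_code, blocksize=1024)
    fp8out = F.dequantize_blockwise(cgrad, state).to(grad_output.dtype)

    # Standard matmul gradients, using the operands saved in forward.
    grad_A = torch.matmul(fp8out, fp8B.t())  # dL/dA = dL/dY @ B^T
    grad_B = torch.matmul(fp8A.t(), fp8out)  # dL/dB = A^T @ dL/dY
    return grad_A, grad_B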
@@ -478,9 +471,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bias=None):
-    assert quant_state is not None
-    return MatMulFP8.apply(A, B, out, bias, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code)


 def matmul(
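With the bias argument dropped, a call now only supplies the two code maps. A usage sketch mirroring how the updated test invokes the function (shapes are illustrative; requires a CUDA device):

import torch
import bitsandbytes as bnb

A = torch.randn(32, 64, device="cuda", dtype=torch.float16)
B = torch.randn(64, 16, device="cuda", dtype=torch.float16)

# Forward code map: E4M3-style; backward code map: E5M2-style (values as in this diff).
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

out = bnb.matmul_fp8(A, B, fw_code, bw_code)  # same call pattern as funcs[1] in the test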
@@ -355,7 +355,9 @@ class LinearFP8(nn.Linear):
             self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), bias=self.bias, fw_code=self.fw_code, code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, code=self.bw_code)
+        if self.bias is not None:
+            out += self.bias

         return out
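For context, the layer's forward now amounts to an FP8 matmul without bias followed by a plain bias add in the activation dtype. A hedged sketch written directly against matmul_fp8 as defined earlier in this diff (the helper name and the nn.Linear-style weight/bias shapes are assumptions):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

def linear_fp8_forward(x, weight, bias=None):
    # Code maps as LinearFP8 builds them: E4M3-style forward, E5M2-style backward.
    fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
    bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)

    # FP8 fake-quantized x @ W^T; the autograd function no longer handles bias.
    out = bnb.matmul_fp8(x, weight.t(), fw_code, bw_code)

    # Bias is applied outside the quantized matmul, in full precision.
    if bias is not None:
        out = out + bias
    return out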
@@ -456,18 +456,16 @@ transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16, torch.float32]
 has_fp16_weights = [True, False]
-has_bias = [True, False]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
-str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
-names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
-def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    if has_bias == False:
-        req_grad = list(req_grad)
-        req_grad[2] = False
+    req_grad = list(req_grad)
+    req_grad[2] = False

     for i in range(k):
         # normal multiply
@@ -475,32 +473,24 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
         B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
         target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
-        bias = None
-        bias2 = None
-        if has_bias:
-            bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
-            bias2 = bias.clone()
         torch.nn.init.xavier_uniform_(B)

-        B2, quant_state = bnb.functional.quantize_fp8(B)
+        fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
+        bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
-
-        if has_bias:
-            out_torch += bias
+            out_bnb = funcs[1](A, B, fw_code, bw_code)

         assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).float().mean().item()
         if n > 0:
-            assert err < 0.115
-
+            assert err < 0.20
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
             torch.cuda.synchronize()
@@ -510,9 +500,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             gradB1 = B.grad
             A.grad = None
             B.grad = None
-            if has_bias:
-                gradBias1 = bias.grad
-                bias.grad = None

             loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
             loss_torch.backward()
@@ -520,12 +507,26 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             gradB2 = B.grad
             A.grad = None
             B.grad = None
-            if has_bias:
-                gradBias2 = bias.grad
-                bias.grad = None

             if req_grad[0]:
                 torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1)

-            if req_grad[2]:
-                torch.testing.assert_allclose(gradBias1, gradBias2)
+            if req_grad[1]:
+                n = gradB1.numel()
+                if dim2 > 0:
+                    assert torch.abs(gradB1).sum() > 0.0
+                    assert torch.abs(gradB2).sum() > 0.0
+                else:
+                    assert torch.abs(gradB1).sum() == 0.0
+                    assert torch.abs(gradB2).sum() == 0.0
+                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+
+                assert (idx == 0).sum().item() <= n * 0.1
+                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                assert (idx == 0).sum().item() <= n * 0.02
+                grad_err = (gradB1-gradB2).abs().mean()
+                assert grad_err.item() < 0.003
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )
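The new assertions boil down to a mean-absolute-error bound on the forward output and a tolerance budget on the weight gradient. A small standalone sketch of the same checks, assuming you already have a bnb output/gradient pair and a reference pair to compare:

import torch

def check_fp8_close(out_bnb, out_torch, gradB_bnb=None, gradB_torch=None):
    # Forward: mean absolute error between the FP8 result and the reference.
    err = torch.abs(out_bnb - out_torch).float().mean().item()
    assert err < 0.20, f"forward mean abs error too large: {err}"

    if gradB_bnb is not None:
        n = gradB_bnb.numel()
        # Backward: at most 10% of elements outside isclose(atol=0.06, rtol=0.3),
        # at most 2% outside isclose(atol=0.10, rtol=0.3), and a small mean error.
        idx = torch.isclose(gradB_bnb, gradB_torch, atol=0.06, rtol=0.3)
        assert (idx == 0).sum().item() <= n * 0.1
        idx = torch.isclose(gradB_bnb, gradB_torch, atol=0.10, rtol=0.3)
        assert (idx == 0).sum().item() <= n * 0.02
        assert (gradB_bnb - gradB_torch).abs().mean().item() < 0.003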