Forward matmul_fp4 tests pass.
Commit 160a83580d (parent 3ac5840c03).
@@ -10,6 +10,7 @@ from .autograd._functions import (
     matmul,
     matmul_cublas,
     mm_cublas,
+    matmul_fp4
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
@@ -2,7 +2,7 @@ import operator
 import warnings
 from dataclasses import dataclass
 from functools import reduce # Required in Python 3
-from typing import Tuple, Optional
+from typing import Tuple, Optional, List
 
 import torch
 
@@ -474,6 +474,67 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None
 
 
+class MatMulFP4(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=None):
+        # default of pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            ctx.bias = bias
+            B_shape = state[1]
+            if A.shape[-1] == B_shape[0]:
+                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Dequantize
+        # 2. Matmul
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+
+        # 3. Save state
+        ctx.state = state
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = A
+        else:
+            ctx.tensors = [None, None]
+            ctx.tensor_states = (None, None)
+            ctx.save_for_backward(None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
+        A = ctx.tensors
+        state = ctx.state
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
+
+        # Cast grad_output to fp16
+        if len(grad_output.shape) == 3:
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+
+        if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
+
+        return grad_A, grad_B, None, grad_bias, None
+
+
 def matmul(
     A: tensor,
     B: tensor,
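Note (not part of the patch): the forward pass above reduces to "dequantize the FP4 weight, then run an ordinary linear op". A minimal sketch of that equivalence, assuming a CUDA build of bitsandbytes and a weight already quantized with quantize_fp4:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4, 64, device='cuda', dtype=torch.float16)
    W = torch.randn(32, 64, device='cuda', dtype=torch.float16)   # out_features x in_features

    W_q, quant_state = F.quantize_fp4(W)          # packed 4-bit weight + (absmax, shape, dtype, blocksize)
    # what MatMulFP4.forward computes:
    out_fp4 = torch.nn.functional.linear(A, F.dequantize_fp4(W_q, quant_state).to(A.dtype))
    out_ref = torch.nn.functional.linear(A, W)    # fp16 reference; agrees up to FP4 quantization error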
@@ -486,3 +547,7 @@ def matmul(
     if threshold > 0.0:
         state.threshold = threshold
     return MatMul8bitLt.apply(A, B, out, bias, state)
+
+
+def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+    return MatMulFP4.apply(A, B, out, bias, quant_state)
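Usage sketch for the new top-level wrapper (not part of the patch; mirrors the "NT" case of the test added further down and assumes a CUDA device):

    import torch
    import bitsandbytes as bnb

    A = torch.randn(8, 32, device='cuda', dtype=torch.float16)
    B = torch.randn(16, 32, device='cuda', dtype=torch.float16)

    B_q, quant_state = bnb.functional.quantize_fp4(B)
    out = bnb.matmul_fp4(A, B_q, quant_state=quant_state)   # approximately A @ B.t()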
@@ -626,7 +626,7 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
     -------
     torch.Tensor:
         The 8-bit tensor with packed 4-bit values.
-    tuple(torch.Tensor, torch.Size, torch.dtype):
+    tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
     if A.device.type != 'cuda':
@@ -640,10 +640,10 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
 
-    state = (absmax, input_shape, A.dtype)
+    state = (absmax, input_shape, A.dtype, blocksize)
 
     if out is None:
-        out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device)
+        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
 
     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
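For reference (not part of the patch), the storage layout implied by this hunk: two 4-bit values are packed per uint8 byte, one absmax scale is kept per block, and the quant_state tuple now carries the blocksize. A small sanity-check sketch, assuming a CUDA device:

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    W_q, quant_state = F.quantize_fp4(W, blocksize=64)

    n = W.numel()
    absmax, input_shape, dtype, blocksize = quant_state        # 4-tuple after this change
    assert W_q.numel() == (n + 1) // 2                         # two 4-bit values per byte
    assert absmax.numel() == (n + blocksize - 1) // blocksize  # one scale per block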
@@ -692,7 +692,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype = quant_state
+        absmax, shape, dtype, blocksize = quant_state
 
 
     if out is None:
@@ -700,6 +700,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
 
     n = out.numel()
 
+
     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:
@@ -710,9 +711,9 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
 
-    return out
+    is_transposed = (True if A.shape[0] == 1 else False)
+    if is_transposed: return out.t()
+    else: return out
 
 
 def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
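Round-trip sketch for the changed return path (not part of the patch; compare test_fp4_quant below): the output is reshaped to the stored input shape, and transposed when the packed tensor is passed in transposed (A.shape[0] == 1):

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(128, 64, device='cuda', dtype=torch.float16)
    W_q, quant_state = F.quantize_fp4(W, blocksize=64)

    assert F.dequantize_fp4(W_q, quant_state).shape == (128, 64)      # plain round trip
    assert F.dequantize_fp4(W_q.t(), quant_state).shape == (64, 128)  # transposed packed input gives transposed output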
@@ -133,6 +133,67 @@ class Embedding(torch.nn.Embedding):
 
         return emb
 
+
+class FP4Params(torch.nn.Parameter):
+    def __new__(cls, data=None, requires_grad=True, quant_state=None):
+        cls.quant_state = None
+        if data is None:
+            data = torch.empty(0)
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+    def cuda(self, device):
+        w = self.data.contiguous().half().cuda(device)
+        w_fp4, quant_state = bnb.functional.quantize_fp4(w)
+        self.data = w_fp4
+        self.quant_state = quant_state
+
+        return self
+
+    @overload
+    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
+        ...
+
+    @overload
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
+        ...
+
+    @overload
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
+        ...
+
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
+            return self.cuda(device)
+        else:
+            new_param = FP4Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad, quant_state=self.quant_state)
+
+            return new_param
+
+
+class LinearFP4(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.state = bnb.MatmulLtState()
+        self.weight = FP4Params(self.weight.data, requires_grad=False)
+
+    def init_8bit_state(self):
+        pass
+
+    def forward(self, x: torch.Tensor):
+        self.state.is_training = self.training
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, 'state', None) is None:
+            print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+        out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state)
+
+        return out
+
+
 class Int8Params(torch.nn.Parameter):
     def __new__(
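Sketch of the quantize-on-device behaviour of FP4Params (not part of the patch; the import path below is an assumption based on the bitsandbytes/nn/modules.py file touched here, and a CUDA device is required):

    import torch
    from bitsandbytes.nn.modules import FP4Params

    p = FP4Params(torch.randn(128, 64, dtype=torch.float16), requires_grad=False)
    p = p.cuda(0)                        # triggers quantize_fp4 on the fp16 weight
    print(p.data.dtype, p.data.shape)    # torch.uint8, ((128*64 + 1)//2, 1)
    print(len(p.quant_state))            # 4: (absmax, shape, dtype, blocksize)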
@@ -208,6 +269,7 @@ class Int8Params(torch.nn.Parameter):
             return new_param
 
 
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                  memory_efficient_backward=False, threshold=0.0, index=None):
@@ -429,3 +429,118 @@ def test_matmullt(
 
                 if req_grad[2]:
                     torch.testing.assert_allclose(gradBias1, gradBias2)
+
+
+n = 1
+k = 3
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
+
+dim2.append(0)
+
+funcs = [(torch.matmul, bnb.matmul_fp4)]
+str_funcs = ["matmul"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True: strval += 'T'
+        else: strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+has_fp16_weights = [True, False]
+has_bias = [True, False]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
+def test_matmul_fp4(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            bias = None
+            bias2 = None
+            if has_bias:
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
+            torch.nn.init.xavier_uniform_(B)
+            B2 = B.clone()
+
+            B2, quant_state = bnb.functional.quantize_fp4(B)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.11
+
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias1 = bias.grad
+                    bias.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias2 = bias.grad
+                    bias.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                if req_grad[1]:
+                    n = gradB1.numel()
+                    if dim2 > 0:
+                        assert torch.abs(gradB1).sum() > 0.0
+                        assert torch.abs(gradB2).sum() > 0.0
+                    else:
+                        assert torch.abs(gradB1).sum() == 0.0
+                        assert torch.abs(gradB2).sum() == 0.0
+                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+
+                    assert (idx == 0).sum().item() <= n * 0.1
+                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.02
+                    torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+
+                if req_grad[2]:
+                    torch.testing.assert_allclose(gradBias1, gradBias2)
@@ -2221,26 +2221,13 @@ def test_fp4_quant():
     A1 = torch.randn(1024, 1024, device='cuda').half()
     qa, SA = F.quantize_fp4(A1, blocksize=64)
     A2 = F.dequantize_fp4(qa, SA)
-    #qa, SA = F.quantize_fp4(A1, blocksize=128)
-    #A2 = F.dequantize_fp4(qa, SA, blocksize=128)
-
-    #A1 = A1.flatten().sort()[0]
-    #A2 = A2.flatten().sort()[0]
-
-    #print(A1)
-    #print(A2)
 
     err = (A1 - A2).abs().float()
     relerr = (err/A1.abs().float()).mean()
     err = err.mean()
 
-    print(err, relerr)
+    assert err.item() < 0.1
+    assert relerr.item() < 0.28
-
-
-
-    #assert err.item() < 0.1
-    #assert relerr.item() < 0.28
-
-
 
 def test_bench_fp4_dequant():