Added bias test for LinearFP4 and basic test.
commit c0c352b379 (parent c361f84239)
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4
@@ -188,9 +188,9 @@ class LinearFP4(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

-        if getattr(self.weight, 'state', None) is None:
-            print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state)
+        if getattr(self.weight, 'quant_state', None) is None:
+            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)

         return out

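The quant_state rename above goes together with a required call order: the FP4 quantization state is attached to the weight only when the layer is moved to the GPU, which is what the warning message asks for. A minimal sketch of that order, assuming the bnb.nn.LinearFP4 API shown in this diff:

    import bitsandbytes as bnb

    # Sketch only (assumes LinearFP4 as imported in this commit). The quantization
    # state lives on the weight and is created when the layer is moved to the GPU,
    # so .cuda()/.to(device) must come before the first forward pass; otherwise
    # quant_state is still None and the layer prints the warning above.
    layer = bnb.nn.LinearFP4(32, 64)
    assert getattr(layer.weight, 'quant_state', None) is None
    layer = layer.cuda()  # quantizes the weight and sets weight.quant_state
    assert layer.weight.quant_state is not None
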
@@ -330,12 +330,8 @@ def test_linear8bitlt_inference(threshold):


 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
     l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
     l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
     l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())

@@ -376,21 +372,10 @@ def test_linear8bitlt_accumulated_gradient():
             torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)


-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-
-
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (
-        bnb.nn.Linear8bitLt(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .cuda()
-        .half()
-    )
+    l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8

     l1.eval()

@@ -446,13 +431,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8

-    mlp = (
-        MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .half()
-        .to("cuda")
-    )
+    mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()

@@ -504,10 +483,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         assert (idx == 0).sum().item() <= b1.numel() * 0.005


-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32

     for i in range(100):

@@ -517,11 +497,12 @@ def test_linear8bitlt_fp32_bias():
         assert l1.bias.dtype == torch.float16

     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert l1.bias is None

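Taken together, the parametrized test above checks the same fp32-bias behavior for both k-bit layers. A minimal usage sketch of what it exercises, assuming the bnb.nn.LinearFP4 API added in this commit:

    import torch
    import bitsandbytes as bnb

    # Sketch only (assumes LinearFP4 as shown in this diff). After .cuda() the
    # weight is stored quantized and the bias is still fp32; the forward pass
    # then casts the bias to the input dtype, which is what the test asserts.
    layer = bnb.nn.LinearFP4(32, 64).cuda()
    assert layer.weight.dtype in [torch.int8, torch.uint8]
    assert layer.bias.dtype == torch.float32

    x = torch.randn(16, 8, 32, device='cuda').half()
    out = layer(x)
    assert layer.bias.dtype == torch.float16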