Added bias test for LinearFP4 and a basic test.

Tim Dettmers 2023-02-05 06:29:52 -08:00
parent c361f84239
commit c0c352b379
3 changed files with 16 additions and 35 deletions

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4
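For context, a minimal usage sketch of the newly exported LinearFP4 layer (illustrative only; it mirrors the tests further down and assumes bitsandbytes is importable as bnb with a CUDA device available):

    import torch
    import bitsandbytes as bnb

    l1 = bnb.nn.LinearFP4(32, 64).cuda()                # moving to the GPU quantizes the weights to FP4
    b1 = torch.randn(16, 8, 32, device="cuda").half()   # fp16 activations, as in the tests
    o1 = l1(b1)                                          # forward dispatches to bnb.matmul_fp4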

@@ -188,9 +188,9 @@ class LinearFP4(nn.Linear):
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
-        if getattr(self.weight, 'state', None) is None:
-            print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state)
+        if getattr(self.weight, 'quant_state', None) is None:
+            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
         return out
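A note on the renamed attribute: the check above implies the FP4 quantization state is only populated once the layer is moved to the GPU. A rough lifecycle sketch based solely on the logic shown in forward (hypothetical snippet, not part of the commit):

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.LinearFP4(32, 64)
    print(getattr(layer.weight, 'quant_state', None))    # None before .cuda(); forward would print the warning
    layer = layer.cuda()                                  # .cuda()/.to(device) initializes weight.quant_state
    x = torch.randn(4, 32, device="cuda").half()
    out = layer(x)                                        # uses bnb.matmul_fp4 with the stored quant_state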

@@ -330,12 +330,8 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(
-        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
-    )
-    l2 = torch.nn.Sequential(
-        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
-    )
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
     l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
     l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
     l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -376,21 +372,10 @@ def test_linear8bitlt_accumulated_gradient():
     torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
-threshold = [0.0, 2.0]
-values = threshold
-names = [f"threshold_{vals}" for vals in values]
-@pytest.mark.parametrize("threshold", values, ids=names)
+@pytest.mark.parametrize("threshold", [0.0, 2.0])
 @pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
-    l1 = (
-        bnb.nn.Linear8bitLt(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .cuda()
-        .half()
-    )
+    l1 = ( bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
     assert l1.weight.dtype == torch.int8
     l1.eval()
@@ -446,13 +431,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    mlp = (
-        MLP8bit(
-            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
-        )
-        .half()
-        .to("cuda")
-    )
+    mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -504,10 +483,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
         assert (idx == 0).sum().item() <= b1.numel() * 0.005
-def test_linear8bitlt_fp32_bias():
+@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+def test_linear_kbit_fp32_bias(module):
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias.dtype == torch.float32
     for i in range(100):
@@ -517,11 +497,12 @@ def test_linear8bitlt_fp32_bias():
         assert l1.bias.dtype == torch.float16
     # casts model to fp16 -> int8 automatically
-    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
-    assert l1.weight.dtype == torch.int8
+    l1 = module(32, 64, bias=False).cuda()
+    assert l1.weight.dtype in [torch.int8, torch.uint8]
     assert l1.bias is None
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
+        assert l1.bias is None
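Taken together, the parametrized test checks the same fp32-bias behaviour for both layer types. A condensed standalone sketch of what it asserts (assumptions: CUDA is available and bitsandbytes is imported as bnb):

    import torch
    import bitsandbytes as bnb

    for module in (lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False),
                   bnb.nn.LinearFP4):
        l1 = module(32, 64).cuda()
        assert l1.weight.dtype in [torch.int8, torch.uint8]   # weights are quantized on .cuda()
        assert l1.bias.dtype == torch.float32                 # bias stays fp32 until the first fp16 forward
        o1 = l1(torch.randn(16, 8, 32, device="cuda").half())
        assert l1.bias.dtype == torch.float16                 # forward casts the bias to the input dtype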