Fixed bug in Linear8bitLt when the bias is None.
parent b00cc9137f
commit 9d60b3c527
@@ -248,10 +248,10 @@ class Linear8bitLt(nn.Linear):
         if self.weight.CB is not None:
             self.init_8bit_state()
-        if self.bias.dtype != torch.float16:
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != torch.float16:
             self.bias.data = self.bias.data.half()
         # assert not self.state.has_fp16_weights
         # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
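The line removed here dereferenced self.bias.dtype unconditionally, so a layer built with bias=False (where self.bias is None) raised an AttributeError on its first forward pass. Below is a minimal sketch of the guarded cast in plain fp32 PyTorch; the class name is hypothetical, and the real layer casts the bias because its matmul path runs in fp16/int8, which this stand-in does not reproduce.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CastBiasLinear(nn.Linear):
        # Hypothetical stand-in illustrating only the bias guard
        # that Linear8bitLt.forward now uses.
        def forward(self, x):
            # Pre-fix code read `self.bias.dtype` directly and crashed
            # whenever the layer was constructed with bias=False.
            if self.bias is not None and self.bias.dtype != torch.float16:
                self.bias.data = self.bias.data.half()
            return F.linear(x, self.weight, self.bias)

    layer = CastBiasLinear(32, 64, bias=False)  # self.bias is None
    out = layer(torch.randn(4, 32))             # guard short-circuits: no AttributeError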
setup.py
@@ -18,7 +18,7 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.32.0",
+    version=f"0.32.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
+
+
+def test_linear8bitlt_fp32_bias():
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias.dtype == torch.float32
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        # the first half-precision forward pass casts the fp32 bias to fp16
+        o1 = l1(b1)
+        assert l1.bias.dtype == torch.float16
+
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias is None
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        o1 = l1(b1)
+        assert l1.bias is None
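Assuming a standard pytest layout for the bitsandbytes test suite (the test file name is not captured in this view), the new case can be selected by name on a CUDA machine:

    pytest -k test_linear8bitlt_fp32_bias

It exercises both paths the fix touches: an fp32 bias that the first half-precision forward pass casts to fp16, and a bias=False layer whose bias must remain None across repeated forward passes.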