Fixed Params4bit init; fixed PyTorch 2.0 test failure.

Tim Dettmers 2023-04-03 18:47:00 -07:00
parent 4ea489d3bf
commit 1ccb7bdec6
3 changed files with 17 additions and 18 deletions


@@ -136,12 +136,14 @@ class Embedding(torch.nn.Embedding):
 class Params4bit(torch.nn.Parameter):
     def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
         cls.quant_state = None
-        cls.blocksize = blocksize
-        cls.compress_statistics = compress_statistics
-        cls.quant_type = quant_type
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        self = torch.Tensor._make_subclass(cls, data, requires_grad)
+        self.blocksize = blocksize
+        self.compress_statistics = compress_statistics
+        self.quant_type = quant_type
+        return self

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
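For context: the old __new__ stored blocksize, compress_statistics, and quant_type on the class object, so every Params4bit in the process shared one set of values, and constructing a second parameter silently rewrote the first one's settings. The fix attaches them to the instance returned by _make_subclass. A minimal sketch of the failure mode (hypothetical Shared class, not bitsandbytes code):

    import torch

    class Shared(torch.nn.Parameter):
        def __new__(cls, data=None, requires_grad=True, quant_type='fp4'):
            cls.quant_type = quant_type  # buggy: class attribute, shared by all instances
            if data is None:
                data = torch.empty(0)
            return torch.Tensor._make_subclass(cls, data, requires_grad)

    a = Shared(torch.zeros(2), quant_type='fp4')
    b = Shared(torch.zeros(2), quant_type='nf4')
    print(a.quant_type)  # prints 'nf4': creating b clobbered a's setting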
@@ -177,16 +179,10 @@ class Params4bit(torch.nn.Parameter):
 class Linear4bit(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
         super().__init__(input_features, output_features, bias)
-        self.state = bnb.MatmulLtState()
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype

-    def init_8bit_state(self):
-        pass
-
     def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
@@ -197,7 +193,7 @@ class Linear4bit(nn.Linear):
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)

-        bias = None if self.bias is None else self.bias.half()
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
         out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
         out = out.to(inp_dtype)
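Taken together, the Linear4bit hunks drop leftover 8-bit machinery (MatmulLtState, init_8bit_state) and cast the bias with compute_dtype instead of hard-coding half precision, so bias and activations reach matmul_4bit in the same dtype while the output is cast back to the input dtype. A usage sketch under those assumptions (requires a CUDA build of bitsandbytes; the layer sizes are arbitrary):

    import torch
    import bitsandbytes as bnb

    # the weight is quantized to 4-bit when the module is moved to the GPU
    layer = bnb.nn.Linear4bit(64, 64, bias=True, compute_dtype=torch.float16).cuda()
    x = torch.randn(8, 64, device='cuda', dtype=torch.float16)
    out = layer(x)               # bias and x are cast to compute_dtype for the 4-bit matmul
    assert out.dtype == x.dtype  # forward() casts the result back to the input dtype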


@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
-    iters = 32
+    iters = 1
     formatB = F.get_special_format_str()
     A = torch.randn(batch, seq, model, device="cuda").half()
@@ -2317,7 +2317,7 @@ def test_bench_4bit_dequant(quant_type):
     #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
-    iters = 500
+    iters = 5
     torch.cuda.synchronize()
     t0 = time.time()
    for i in range(iters):
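Both test hunks only shrink the iteration counts of timing loops (32 to 1, 500 to 5), presumably to keep these benchmark tests cheap; the measurement logic is unchanged. For reference, the pattern they rely on brackets the loop with torch.cuda.synchronize() so the host-side clock actually covers the device work (sketch with a placeholder workload):

    import time
    import torch

    def bench(fn, iters=5):
        fn()                      # warmup / lazy initialization
        torch.cuda.synchronize()  # drain pending kernels before starting the clock
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()  # wait for the timed kernels to finish
        return (time.time() - t0) / iters

    b = torch.randn(128, 1024*12, device='cuda').half()
    print(bench(lambda: b @ b.t()))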


@@ -558,14 +558,17 @@ def test_kbit_backprop(module):
         relerrs1.append(relerr1.mean().item())
         relerrs2.append(relerr2.mean().item())
-        #torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
-        #torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        if isinstance(module, bnb.nn.Linear8bitLt):
+            torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
+            torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        else:
+            torch.testing.assert_allclose(grad1, grad2, atol=0.015, rtol=0.05)
+            torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.02, rtol=0.05)
         ref.zero_grad()
         kbit.zero_grad()
-        assert kbit[0].weight.grad.sum().item() == 0
-        assert kbit[0].bias.grad.sum().item() == 0
+        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
+        assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0

     print('out', sum(errs1)/len(errs1))
     print('grad', sum(errs2)/len(errs2))
     print('rel out', sum(relerrs1)/len(relerrs1))
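Two things happen in this hunk. The previously commented-out gradient checks are re-enabled, with tighter tolerances for Linear8bitLt and looser ones for the other (4-bit) modules, presumably because their quantization error is larger. The "is None or" guards are the PyTorch 2.0 fix from the commit message: as of 2.0, Module.zero_grad() defaults to set_to_none=True, so after kbit.zero_grad() the .grad attributes are None rather than zero tensors, and the old .sum() calls raised AttributeError. A minimal illustration:

    import torch

    lin = torch.nn.Linear(4, 4)
    lin(torch.randn(2, 4)).sum().backward()
    lin.zero_grad(set_to_none=False)  # pre-2.0 default: grads become zero tensors
    assert lin.weight.grad.sum().item() == 0

    lin(torch.randn(2, 4)).sum().backward()
    lin.zero_grad()                   # PyTorch 2.0 default: grads become None
    assert lin.weight.grad is None or lin.weight.grad.sum().item() == 0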