Fixed Params4bit init; fixed PyTorch 2.0 test failure.
parent 4ea489d3bf
commit 1ccb7bdec6
bitsandbytes/nn/modules.py

@@ -136,12 +136,14 @@ class Embedding(torch.nn.Embedding):
 class Params4bit(torch.nn.Parameter):
     def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
         cls.quant_state = None
-        cls.blocksize = blocksize
-        cls.compress_statistics = compress_statistics
-        cls.quant_type = quant_type
         if data is None:
             data = torch.empty(0)
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+        self = torch.Tensor._make_subclass(cls, data, requires_grad)
+        self.blocksize = blocksize
+        self.compress_statistics = compress_statistics
+        self.quant_type = quant_type
+        return self
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
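This hunk is the init fix named in the commit message: `__new__` previously assigned blocksize, compress_statistics, and quant_type to cls, i.e. to the Params4bit class itself, so every instance shared whatever values the most recently constructed one had set. The rewrite attaches them to the instance returned by Tensor._make_subclass. A minimal standalone sketch of the difference, with hypothetical class names (not the bitsandbytes API):

```python
import torch

class BadParam(torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=False, blocksize=64):
        cls.blocksize = blocksize  # class attribute: shared by ALL instances
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

class GoodParam(torch.nn.Parameter):
    def __new__(cls, data=None, requires_grad=False, blocksize=64):
        if data is None:
            data = torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.blocksize = blocksize  # instance attribute, as intended
        return self

a = BadParam(torch.zeros(2), blocksize=64)
b = BadParam(torch.zeros(2), blocksize=128)
assert a.blocksize == 128  # a's setting was clobbered when b was constructed

a = GoodParam(torch.zeros(2), blocksize=64)
b = GoodParam(torch.zeros(2), blocksize=128)
assert a.blocksize == 64   # each instance keeps its own setting
```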
@@ -177,16 +179,10 @@ class Params4bit(torch.nn.Parameter):
 class Linear4bit(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
         super().__init__(input_features, output_features, bias)
-        self.state = bnb.MatmulLtState()
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype
 
-    def init_8bit_state(self):
-        pass
-
     def forward(self, x: torch.Tensor):
-        self.state.is_training = self.training
-
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
@@ -197,7 +193,7 @@ class Linear4bit(nn.Linear):
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)
 
-        bias = None if self.bias is None else self.bias.half()
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
         out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
 
         out = out.to(inp_dtype)
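The one-line change above makes the bias cast follow compute_dtype instead of hard-coding float16 via .half(). A standalone plain-PyTorch sketch of why that matters when compute_dtype is not float16:

```python
import torch

# With compute_dtype set to bfloat16, .half() would still produce a float16
# bias that no longer matches the activations; .to(compute_dtype) tracks it.
compute_dtype = torch.bfloat16
bias = torch.randn(4, dtype=torch.float32)

print(bias.half().dtype)             # torch.float16, regardless of compute_dtype
print(bias.to(compute_dtype).dtype)  # torch.bfloat16, matches x after the cast
```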
tests/test_functional.py

@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
-    iters = 32
+    iters = 1
     formatB = F.get_special_format_str()
 
     A = torch.randn(batch, seq, model, device="cuda").half()
@@ -2317,7 +2317,7 @@ def test_bench_4bit_dequant(quant_type):
     #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
 
-    iters = 500
+    iters = 5
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
tests/test_modules.py

@@ -558,14 +558,17 @@ def test_kbit_backprop(module):
         relerrs1.append(relerr1.mean().item())
         relerrs2.append(relerr2.mean().item())
 
-
-        #torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
-        #torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        if isinstance(module, bnb.nn.Linear8bitLt):
+            torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
+            torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
+        else:
+            torch.testing.assert_allclose(grad1, grad2, atol=0.015, rtol=0.05)
+            torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.02, rtol=0.05)
         ref.zero_grad()
         kbit.zero_grad()
 
-        assert kbit[0].weight.grad.sum().item() == 0
-        assert kbit[0].bias.grad.sum().item() == 0
+        assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
+        assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
     print('out', sum(errs1)/len(errs1))
     print('grad', sum(errs2)/len(errs2))
     print('rel out', sum(relerrs1)/len(relerrs1))
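The rewritten asserts are the PyTorch 2.0 fix from the commit message: since PyTorch 2.0, Module.zero_grad() defaults to set_to_none=True, so .grad becomes None rather than a zero-filled tensor, and the old `grad.sum().item() == 0` checks raised AttributeError on None. A minimal standalone sketch of the behavior the test now tolerates:

```python
import torch

lin = torch.nn.Linear(4, 4)
lin(torch.randn(2, 4)).sum().backward()

lin.zero_grad()  # PyTorch >= 2.0 default: set_to_none=True, grads become None
assert lin.weight.grad is None or lin.weight.grad.sum().item() == 0

lin(torch.randn(2, 4)).sum().backward()
lin.zero_grad(set_to_none=False)  # pre-2.0 behavior: grads are zeroed in place
assert lin.weight.grad.sum().item() == 0
```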