Added test for Params4bit.to() and fixed double quant behavior.

This commit is contained in:
Tim Dettmers 2023-07-09 17:16:50 -07:00
parent 6a905be5ce
commit cef519c89e
3 changed files with 4 additions and 5 deletions

View File

@ -831,8 +831,6 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
if compress_statistics:
offset = absmax.mean()
absmax -= offset
#code = create_custom_map().to(absmax.device)
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
del absmax
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]

View File

@ -188,9 +188,9 @@ class Params4bit(torch.nn.Parameter):
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
# for 8-bit
s[-2][0] = s[-2][0].to(device) # offset
s[-2][1][0] = s[-2][1][0].to(device) # nested quantization state statistics
s[-2][1][1] = s[-2][1][1].to(device) # nested quantization codebook
s[-3][0] = s[-3][0].to(device) # offset
s[-3][1][0] = s[-3][1][0].to(device) # nested quantization state statistics
s[-3][1][1] = s[-3][1][1].to(device) # nested quantization codebook
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad, quant_state=self.quant_state,
blocksize=self.blocksize, compress_statistics=self.compress_statistics,

View File

@ -535,6 +535,7 @@ def test_kbit_backprop(module):
kbit[1].bias.detach().copy_(ref[1].bias)
ref = ref.half().cuda()
kbit = kbit.half().cuda()
kbit = kbit.half().to('cuda')
errs1 = []
errs2 = []