Added test for Param4bit.to() and fixed double quant behavior.
This commit is contained in:
parent
6a905be5ce
commit
cef519c89e
|
@ -831,8 +831,6 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
|
|||
if compress_statistics:
|
||||
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
#code = create_custom_map().to(absmax.device)
|
||||
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
|
||||
del absmax
|
||||
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
|
||||
|
|
|
@ -188,9 +188,9 @@ class Params4bit(torch.nn.Parameter):
|
|||
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
|
||||
|
||||
# for 8-bit
|
||||
s[-2][0] = s[-2][0].to(device) # offset
|
||||
s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics
|
||||
s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook
|
||||
s[-3][0] = s[-3][0].to(device) # offset
|
||||
s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics
|
||||
s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook
|
||||
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
|
||||
requires_grad=self.requires_grad, quant_state=self.quant_state,
|
||||
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
|
||||
|
|
|
@ -535,6 +535,7 @@ def test_kbit_backprop(module):
|
|||
kbit[1].bias.detach().copy_(ref[1].bias)
|
||||
ref = ref.half().cuda()
|
||||
kbit = kbit.half().cuda()
|
||||
kbit = kbit.half().to('cuda')
|
||||
|
||||
errs1 = []
|
||||
errs2 = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user