Added test for Params4bit.to() and fixed double quant behavior.
This commit is contained in:
parent
6a905be5ce
commit
cef519c89e
|
@ -831,8 +831,6 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
|
||||||
if compress_statistics:
|
if compress_statistics:
|
||||||
offset = absmax.mean()
|
offset = absmax.mean()
|
||||||
absmax -= offset
|
absmax -= offset
|
||||||
#code = create_custom_map().to(absmax.device)
|
|
||||||
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
|
|
||||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
|
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
|
||||||
del absmax
|
del absmax
|
||||||
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
|
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
|
||||||
|
|
|
@ -188,9 +188,9 @@ class Params4bit(torch.nn.Parameter):
|
||||||
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
|
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
|
||||||
|
|
||||||
# for 8-bit
|
# for 8-bit
|
||||||
s[-2][0] = s[-2][0].to(device) # offset
|
s[-3][0] = s[-3][0].to(device) # offset
|
||||||
s[-2][1][0] = s[-2][1][0].to(device) # nested quantization state statistics
|
s[-3][1][0] = s[-3][1][0].to(device) # nested quantization state statistics
|
||||||
s[-2][1][1] = s[-2][1][1].to(device) # nested quantization codebook
|
s[-3][1][1] = s[-3][1][1].to(device) # nested quantization codebook
|
||||||
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
|
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
|
||||||
requires_grad=self.requires_grad, quant_state=self.quant_state,
|
requires_grad=self.requires_grad, quant_state=self.quant_state,
|
||||||
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
|
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
|
||||||
|
|
|
@ -535,6 +535,7 @@ def test_kbit_backprop(module):
|
||||||
kbit[1].bias.detach().copy_(ref[1].bias)
|
kbit[1].bias.detach().copy_(ref[1].bias)
|
||||||
ref = ref.half().cuda()
|
ref = ref.half().cuda()
|
||||||
kbit = kbit.half().cuda()
|
kbit = kbit.half().cuda()
|
||||||
|
kbit = kbit.half().to('cuda')
|
||||||
|
|
||||||
errs1 = []
|
errs1 = []
|
||||||
errs2 = []
|
errs2 = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user