From cef519c89ed04fdd6f3c09a672f8520532a89994 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 9 Jul 2023 17:16:50 -0700 Subject: [PATCH] Added test for Param4bit.to() and fixed double quant behavior. --- bitsandbytes/functional.py | 2 -- bitsandbytes/nn/modules.py | 6 +++--- tests/test_modules.py | 1 + 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 78b5f4b..c5514ed 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -831,8 +831,6 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz if compress_statistics: offset = absmax.mean() absmax -= offset - #code = create_custom_map().to(absmax.device) - #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype] diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 3284921..2407afb 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -188,9 +188,9 @@ class Params4bit(torch.nn.Parameter): #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax # for 8-bit - s[-2][0] = s[-2][0].to(device) # offset - s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics - s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook + s[-3][0] = s[-3][0].to(device) # offset + s[-3][1][0] = s[-3][1][0].to(device) # nested quantiation state statitics + s[-3][1][1] = s[-3][1][1].to(device) # nested quantiation codebook new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=self.quant_state, blocksize=self.blocksize, compress_statistics=self.compress_statistics, diff --git a/tests/test_modules.py b/tests/test_modules.py index d0a9051..a187484 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -535,6 +535,7 @@ def test_kbit_backprop(module): kbit[1].bias.detach().copy_(ref[1].bias) ref = ref.half().cuda() kbit = kbit.half().cuda() + kbit = kbit.half().to('cuda') errs1 = [] errs2 = []