From cd4d904a4ccc80c444e460d3aef20705895d2051 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Sat, 25 Feb 2023 06:01:34 +0100 Subject: [PATCH] Raise an error when loading a quantized checkpoint before quantization --- bitsandbytes/nn/modules.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 754ba20..65b2102 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear): for key in unexpected_keys: input_name = key[len(prefix):] if input_name == "SCB": + if self.weight.SCB is None: + # buffers not yet initialized, can't copy into them without quantizing first + raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()") + input_param = state_dict[key] self.weight.SCB.copy_(input_param) unexpected_keys.remove(key)