Raise an error when loading a quantized checkpoint before quantization

This commit is contained in:
Max Ryabinin 2023-02-25 06:01:34 +01:00
parent ac3ab281e3
commit cd4d904a4c

View File

@@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear):
for key in unexpected_keys:
input_name = key[len(prefix):]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
unexpected_keys.remove(key)