Raise an error when loading a quantized checkpoint before quantization
This commit is contained in:
parent ac3ab281e3
commit cd4d904a4c
@@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear):
         for key in unexpected_keys:
             input_name = key[len(prefix):]
             if input_name == "SCB":
+                if self.weight.SCB is None:
+                    # buffers not yet initialized, can't call them directly without
+                    raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
+                                       "not supported. Please call module.cuda() before module.load_state_dict()")
+
                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
                 unexpected_keys.remove(key)
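The check guards against a specific ordering mistake: weight.SCB, the buffer that holds the int8 quantization scales, is only created when the layer is quantized, which happens when the module is moved to the GPU with module.cuda(). If a quantized checkpoint is loaded before that, there is nowhere to copy the SCB entry, so the commit fails fast with an explicit error instead of silently dropping the scales. Below is a minimal sketch of the two load orders, assuming a bitsandbytes build that serializes SCB in the state dict (the feature this commit is part of) and an available CUDA device; the layer sizes are illustrative.

import bitsandbytes as bnb

# Quantize a layer by moving it to the GPU, then capture its checkpoint.
# With has_fp16_weights=False, .cuda() converts the weight to int8 and
# fills weight.SCB with the quantization scales.
quantized = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
checkpoint = quantized.state_dict()  # includes the "SCB" entry

# Wrong order: the fresh module was never quantized, so weight.SCB is None
# and load_state_dict() raises the RuntimeError added by this commit.
fresh = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
try:
    fresh.load_state_dict(checkpoint)
except RuntimeError as err:
    print(err)

# Correct order: quantize first via .cuda(), then load; the SCB scales are
# copied into the already-initialized buffer.
fresh = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
fresh.load_state_dict(checkpoint)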