Raise an error when loading a quantized checkpoint before quantization
This commit is contained in:
parent
ac3ab281e3
commit
cd4d904a4c
|
@ -248,6 +248,11 @@ class Linear8bitLt(nn.Linear):
|
||||||
for key in unexpected_keys:
|
for key in unexpected_keys:
|
||||||
input_name = key[len(prefix):]
|
input_name = key[len(prefix):]
|
||||||
if input_name == "SCB":
|
if input_name == "SCB":
|
||||||
|
if self.weight.SCB is None:
|
||||||
|
# buffers not yet initialized, can't call them directly without
|
||||||
|
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||||
|
"not supported. Please call module.cuda() before module.load_state_dict()")
|
||||||
|
|
||||||
input_param = state_dict[key]
|
input_param = state_dict[key]
|
||||||
self.weight.SCB.copy_(input_param)
|
self.weight.SCB.copy_(input_param)
|
||||||
unexpected_keys.remove(key)
|
unexpected_keys.remove(key)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user