Improve memory efficiency of 8-bit serialization
This commit is contained in:
parent
4fb37d45c1
commit
f734076e94
|
@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn
|
|||
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional
|
||||
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
|
||||
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
|
||||
from bitsandbytes.optim import GlobalOptimManager
|
||||
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
|
||||
|
||||
|
@ -306,7 +306,6 @@ class Int8Params(torch.nn.Parameter):
|
|||
return new_param
|
||||
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
|
||||
memory_efficient_backward=False, threshold=0.0, index=None):
|
||||
|
@ -324,50 +323,58 @@ class Linear8bitLt(nn.Linear):
|
|||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
|
||||
# reorder weight layout back from ampere/turing to row
|
||||
reorder_layout = True
|
||||
weight_clone = self.weight.data.clone()
|
||||
else:
|
||||
reorder_layout = False
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
try:
|
||||
if reorder_layout:
|
||||
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
scb_name = "SCB"
|
||||
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, scb_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, scb_name)
|
||||
# case 3: SCB is in self.state, weight layout reordered after first forward()
|
||||
layout_reordered = self.state.CxB is not None
|
||||
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
weight_name = "SCB"
|
||||
key_name = prefix + f"{scb_name}"
|
||||
format_name = prefix + "weight_format"
|
||||
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, weight_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, weight_name)
|
||||
|
||||
key_name = prefix + f"{weight_name}"
|
||||
if not self.state.has_fp16_weights:
|
||||
if param_from_weight is not None:
|
||||
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
|
||||
elif not self.state.has_fp16_weights and param_from_state is not None:
|
||||
destination[format_name] = "row"
|
||||
elif param_from_state is not None and not layout_reordered:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
finally:
|
||||
if reorder_layout:
|
||||
self.weight.data = weight_clone
|
||||
destination[format_name] = "row"
|
||||
elif param_from_state is not None:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
destination[format_name] = self.state.formatB
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
for key in unexpected_keys:
|
||||
unexpected_copy = list(unexpected_keys)
|
||||
|
||||
for key in unexpected_copy:
|
||||
input_name = key[len(prefix):]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers not yet initialized, can't call them directly without
|
||||
# buffers not yet initialized, can't access them directly without quantizing first
|
||||
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()")
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
|
||||
if self.state.SCB is not None:
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
||||
unexpected_keys.remove(key)
|
||||
if input_name == "weight_format":
|
||||
input_param = state_dict[key]
|
||||
if input_param != "row":
|
||||
tile_indices = get_tile_inds(input_param, self.weight.device)
|
||||
self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def init_8bit_state(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user