Improve memory efficiency of 8-bit serialization

This commit is contained in:
Max Ryabinin 2023-06-09 21:39:57 +02:00
parent 4fb37d45c1
commit f734076e94

View File

@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@ -306,7 +306,6 @@ class Int8Params(torch.nn.Parameter):
return new_param
class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None):
@ -324,50 +323,58 @@ class Linear8bitLt(nn.Linear):
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
# reorder weight layout back from ampere/turing to row
reorder_layout = True
weight_clone = self.weight.data.clone()
else:
reorder_layout = False
super()._save_to_state_dict(destination, prefix, keep_vars)
try:
if reorder_layout:
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
scb_name = "SCB"
super()._save_to_state_dict(destination, prefix, keep_vars)
# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, scb_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, scb_name)
# case 3: SCB is in self.state, weight layout reordered after first forward()
layout_reordered = self.state.CxB is not None
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
weight_name = "SCB"
key_name = prefix + f"{scb_name}"
format_name = prefix + "weight_format"
# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, weight_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, weight_name)
key_name = prefix + f"{weight_name}"
if not self.state.has_fp16_weights:
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
elif not self.state.has_fp16_weights and param_from_state is not None:
destination[format_name] = "row"
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
finally:
if reorder_layout:
self.weight.data = weight_clone
destination[format_name] = "row"
elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
for key in unexpected_keys:
unexpected_copy = list(unexpected_keys)
for key in unexpected_copy:
input_name = key[len(prefix):]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without
# buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")
input_param = state_dict[key]
self.weight.SCB.copy_(input_param)
if self.state.SCB is not None:
self.state.SCB = self.weight.SCB
unexpected_keys.remove(key)
if input_name == "weight_format":
input_param = state_dict[key]
if input_param != "row":
tile_indices = get_tile_inds(input_param, self.weight.device)
self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
unexpected_keys.remove(key)
def init_8bit_state(self):