Merge pull request #503 from TimDettmers/efficient_8bit_serialize
Make 8-bit serialization more memory-efficient (v2)
commit 2d321a7524
bitsandbytes/autograd/_functions.py

@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
     return True


+def _get_tile_size(format):
+    assert format in (
+        "col_turing",
+        "col_ampere",
+    ), f"please find this assert and manually enter tile size for {format}"
+    return (8, 32) if format == "col_turing" else (32, 32)
+
+
+def get_tile_inds(format, device):
+    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
+    with torch.no_grad():
+        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
+
+
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
@@ -267,20 +280,10 @@ class MatmulLtState:
         self.SBt = None
         self.CBt = None

-    def get_tile_size(self):
-        assert self.formatB in (
-            "col_turing",
-            "col_ampere",
-        ), f"please find this assert and manually enter tile size for {self.formatB}"
-        return (8, 32) if self.formatB == "col_turing" else (32, 32)
-
     @property
     def tile_indices(self):
         if self._tile_indices is None:
-            device = self.CxB.device
-            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
-            with torch.no_grad():
-                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
         return self._tile_indices
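For orientation, here is a minimal round-trip sketch using the relocated helpers. It is not code from the PR; it assumes a CUDA device, a CUDA build of bitsandbytes with this commit's get_tile_inds available, and an arbitrarily chosen 128x128 int8 tensor whose dimensions are multiples of the tile size. get_tile_inds computes the inverse permutation for a tiled format once, and undo_layout applies it to recover the row-major tensor:

import torch
import bitsandbytes.functional as F
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout

W_row = torch.randint(-128, 127, (128, 128), dtype=torch.int8, device="cuda")
CxB, _ = F.transform(W_row, from_order="row", to_order="col_turing")  # tiled layout used by igemmlt
inds = get_tile_inds("col_turing", W_row.device)                      # inverse permutation for this format
W_back = undo_layout(CxB, inds)                                       # back to contiguous row-major
assert torch.equal(W_back, W_row)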
bitsandbytes/nn/modules.py

@@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn

 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@@ -306,6 +306,17 @@ class Int8Params(torch.nn.Parameter):
         return new_param


+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict.get(f"{prefix}weight")
+    if weight is None:
+        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
+        return
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
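The new hook relies on nn.Module's load_state_dict pre-hook mechanism: a function registered with _register_load_state_dict_pre_hook is called with the state dict before the module copies values into its parameters, and may rewrite or pop entries in place. A standalone sketch with a hypothetical toy module and hook (names here are illustrative, not from the PR):

import torch
import torch.nn as nn

def scale_weight_hook(state_dict, prefix, local_metadata, strict,
                      missing_keys, unexpected_keys, error_msgs):
    # runs before the actual copy; may freely rewrite or pop entries
    key = f"{prefix}weight"
    if key in state_dict:
        state_dict[key] = state_dict[key] * 2.0

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 2))
        self._register_load_state_dict_pre_hook(scale_weight_hook)

m = Toy()
m.load_state_dict({"weight": torch.ones(2, 2)})
print(m.weight)  # filled with 2.0 -- the hook ran before the parameters were copied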
|
@@ -322,52 +333,55 @@ class Linear8bitLt(nn.Linear):
         self.state.use_pool = True

         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)

     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
-
-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
-
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
-
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
-            if param_from_weight is not None:
-                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
-                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        scb_name = "SCB"
+
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, scb_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None
+
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"
+
+        if not self.state.has_fp16_weights:
+            if param_from_weight is not None:
+                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = "row"
+            elif param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = self.state.formatB

     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                       error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")

                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
+
+                if self.state.SCB is not None:
+                    self.state.SCB = self.weight.SCB
+
                 unexpected_keys.remove(key)

     def init_8bit_state(self):
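End-to-end, the new serialization path can be exercised roughly as below. This is a sketch, not code from the PR: it assumes a CUDA device and a CUDA build of bitsandbytes, and the layer sizes are arbitrary. After the first forward() the weight lives in the tiled col_turing/col_ampere layout; state_dict() now saves it as-is together with a weight_format entry instead of cloning and re-converting it, and maybe_rearrange_weight restores the row-major layout when the checkpoint is loaded:

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False).cuda()
x = torch.randn(4, 64, dtype=torch.float16, device="cuda")

y_ref = layer(x)              # first forward() switches the weight to the tiled GPU layout
sd = layer.state_dict()       # contains "weight" (still tiled), "SCB", and "weight_format"

restored = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False).cuda()
restored.load_state_dict(sd)  # the pre-hook undoes the tiled layout before the copy
torch.testing.assert_close(restored(x), y_ref)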