Revert the layout if weights were reordered
commit cc608c04c2
parent cd4d904a4c
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
+import bitsandbytes.functional
+from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
 from bitsandbytes.optim import GlobalOptimManager
 
 T = TypeVar("T", bound="torch.nn.Module")
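The two new imports provide the layout round trip the updated save path relies on: bitsandbytes.functional.transform permutes a row-major int8 tensor into the GPU-specific tiled format, while get_inverse_transform_indices and undo_layout recover the row-major tensor exactly. A minimal sketch of that round trip, assuming a CUDA device; the "col_turing" order, the (8, 32) tile size, and the 1024x1024 shape are illustrative, whereas the module reads them from state.formatB and state.get_tile_size():

```python
import torch
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout

device = torch.device("cuda")
order, tile_size = "col_turing", (8, 32)  # assumed here; taken from MatmulLtState inside the module

# row-major int8 -> GPU-specific tiled layout (same lambda shape as in the diff below)
transform = lambda x: bitsandbytes.functional.transform(
    x.to(device), from_order="row", to_order=order)[0].to(x.device)

weight = torch.randint(-128, 128, (1024, 1024), dtype=torch.int8, device=device)
tiled = transform(weight)

# the inverse permutation is computed once from a single tile and can be cached
with torch.no_grad():
    tile_indices = get_inverse_transform_indices(transform, tile_size).to(device)

row_major = undo_layout(tiled, tile_indices)
assert torch.equal(row_major, weight)  # the permutation round-trips exactly
```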
@@ -210,7 +212,7 @@ class Int8Params(torch.nn.Parameter):
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                  memory_efficient_backward=False, threshold=0.0, index=None):
         super().__init__(input_features, output_features, bias)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
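For context, a hypothetical setup that reaches the state the new save path targets, assuming bitsandbytes around 0.37 on an int8-capable GPU (Turing or newer); the layer size, threshold, and batch shape are made up for the example:

```python
import torch
import bitsandbytes as bnb

# inference-only layer: has_fp16_weights=False keeps only the quantized weight
layer = bnb.nn.Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False, threshold=6.0)
layer = layer.cuda().eval()  # Int8Params.cuda() quantizes the fp32 weight to int8 CB plus SCB scales

with torch.no_grad():
    # the first forward pass converts the row-major CB into the tiled CxB and drops CB,
    # which is exactly the case the reordering branch in _save_to_state_dict below handles
    layer(torch.randn(8, 1024, dtype=torch.float16, device="cuda"))

print(layer.state.CB is None, layer.state.CxB is not None)  # True True
```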
@@ -225,21 +227,48 @@ class Linear8bitLt(nn.Linear):
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-        weight_name = "SCB"
-
-        # case 1: .cuda was called, SCB is in self.weight
-        param_from_weight = getattr(self.weight, weight_name)
-        # case 2: self.init_8bit_state was called, SCB is in self.state
-        param_from_state = getattr(self.state, weight_name)
-
-        key_name = prefix + f"{weight_name}"
-        if param_from_weight is not None:
-            destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-        elif not self.state.has_fp16_weights and param_from_state is not None:
-            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
+            # reorder weight layout back from ampere/turing to row
+            reorder_layout = True
+            weight_clone = self.weight.data.clone()
+        else:
+            reorder_layout = False
+
+        try:
+            if reorder_layout:
+                if self.state.tile_indices is None:
+                    order, tile_size = self.state.formatB, self.state.get_tile_size()
+                    transform = lambda x: \
+                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row",
+                                                          to_order=order)[0].to(x.device)
+                    with torch.no_grad():
+                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(
+                            self.state.CxB.device)
+
+                CB = (
+                    undo_layout(self.state.CxB, self.state.tile_indices)
+                )
+
+                self.weight.data = CB
+
+            super()._save_to_state_dict(destination, prefix, keep_vars)
+
+            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+            weight_name = "SCB"
+
+            # case 1: .cuda was called, SCB is in self.weight
+            param_from_weight = getattr(self.weight, weight_name)
+            # case 2: self.init_8bit_state was called, SCB is in self.state
+            param_from_state = getattr(self.state, weight_name)
+
+            key_name = prefix + f"{weight_name}"
+            if param_from_weight is not None:
+                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+        finally:
+            if reorder_layout:
+                self.weight.data = weight_clone
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
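The effect on the checkpoint, continuing the illustrative layer from the snippet above (shapes and the file name are assumptions of the example): the weight is saved back in plain row-major order, SCB is the only extra entry, and the finally block restores the module's tiled weight so inference is unaffected.

```python
import torch
import bitsandbytes as bnb

# same illustrative layer as above: quantized via .cuda() and run once so state.CxB is populated
layer = bnb.nn.Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False, threshold=6.0).cuda().eval()
with torch.no_grad():
    layer(torch.randn(8, 1024, dtype=torch.float16, device="cuda"))

sd = layer.state_dict()

# the quantized weight is serialized row-major, not in the turing/ampere tile order
print(sd["weight"].dtype, sd["weight"].shape)   # torch.int8 torch.Size([1024, 1024])
# SCB, the row-wise quantization scales, rides along as the only extra entry
print(sd["SCB"].shape)                          # torch.Size([1024]) for this layer
# the live module is untouched because the weight clone is put back in `finally`
print(layer.state.CxB is not None, layer.state.CB is None)  # True True

torch.save(sd, "int8_linear.pt")  # checkpoint no longer depends on the GPU-specific layout
```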