Rearrange the weights directly in state dict before loading
This commit is contained in:
parent
f734076e94
commit
c1f3f56d2c
|
@ -306,6 +306,15 @@ class Int8Params(torch.nn.Parameter):
|
||||||
return new_param
|
return new_param
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
weight = state_dict[f"{prefix}weight"]
|
||||||
|
weight_format = state_dict.pop(f"{prefix}weight_format", "row")
|
||||||
|
|
||||||
|
if weight_format != "row":
|
||||||
|
tile_indices = get_tile_inds(weight_format, weight.device)
|
||||||
|
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
|
||||||
|
|
||||||
|
|
||||||
class Linear8bitLt(nn.Linear):
|
class Linear8bitLt(nn.Linear):
|
||||||
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
|
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
|
||||||
memory_efficient_backward=False, threshold=0.0, index=None):
|
memory_efficient_backward=False, threshold=0.0, index=None):
|
||||||
|
@ -321,6 +330,7 @@ class Linear8bitLt(nn.Linear):
|
||||||
self.state.use_pool = True
|
self.state.use_pool = True
|
||||||
|
|
||||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||||
|
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
|
||||||
|
|
||||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||||
|
@ -370,12 +380,6 @@ class Linear8bitLt(nn.Linear):
|
||||||
self.state.SCB = self.weight.SCB
|
self.state.SCB = self.weight.SCB
|
||||||
|
|
||||||
unexpected_keys.remove(key)
|
unexpected_keys.remove(key)
|
||||||
if input_name == "weight_format":
|
|
||||||
input_param = state_dict[key]
|
|
||||||
if input_param != "row":
|
|
||||||
tile_indices = get_tile_inds(input_param, self.weight.device)
|
|
||||||
self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
|
|
||||||
unexpected_keys.remove(key)
|
|
||||||
|
|
||||||
def init_8bit_state(self):
|
def init_8bit_state(self):
|
||||||
self.state.CB = self.weight.CB
|
self.state.CB = self.weight.CB
|
||||||
|
|
Loading…
Reference in New Issue
Block a user