Rearrange the weights directly in state dict before loading

This commit is contained in:
Max Ryabinin 2023-06-09 21:58:39 +02:00
parent f734076e94
commit c1f3f56d2c

View File

@ -306,6 +306,15 @@ class Int8Params(torch.nn.Parameter):
return new_param return new_param
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight = state_dict[f"{prefix}weight"]
weight_format = state_dict.pop(f"{prefix}weight_format", "row")
if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
class Linear8bitLt(nn.Linear): class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None): memory_efficient_backward=False, threshold=0.0, index=None):
@ -321,6 +330,7 @@ class Linear8bitLt(nn.Linear):
self.state.use_pool = True self.state.use_pool = True
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
def _save_to_state_dict(self, destination, prefix, keep_vars): def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars) super()._save_to_state_dict(destination, prefix, keep_vars)
@ -370,12 +380,6 @@ class Linear8bitLt(nn.Linear):
self.state.SCB = self.weight.SCB self.state.SCB = self.weight.SCB
unexpected_keys.remove(key) unexpected_keys.remove(key)
if input_name == "weight_format":
input_param = state_dict[key]
if input_param != "row":
tile_indices = get_tile_inds(input_param, self.weight.device)
self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
unexpected_keys.remove(key)
def init_8bit_state(self): def init_8bit_state(self):
self.state.CB = self.weight.CB self.state.CB = self.weight.CB