import shlex
import subprocess
from typing import Tuple

import torch


def execute_and_return(command_string: str) -> Tuple[str, str]:
    """Run a shell command and return its decoded (stdout, stderr) streams as stripped strings."""

    def _decode(subprocess_err_out_tuple):
        # Decode the raw byte streams returned by Popen.communicate().
        return tuple(
            to_decode.decode("UTF-8").strip()
            for to_decode in subprocess_err_out_tuple
        )

    def execute_and_return_decoded_std_streams(command_string):
        return _decode(
            subprocess.Popen(
                shlex.split(command_string),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            ).communicate()
        )

    std_out, std_err = execute_and_return_decoded_std_streams(command_string)
    return std_out, std_err
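
# Usage sketch (illustrative only, not part of the original module; the command shown is
# an arbitrary example):
#
#     std_out, std_err = execute_and_return("nvidia-smi")
#     if std_err:
#         print(f"stderr: {std_err}")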


def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
    """
    Replace all linear modules in a model with a new Linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects the standard
            arguments (`in_features`, `out_features`, `bias`). If other arguments
            need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
        copy_weights (`bool`):
            Copy the weights from the old linear module to the new one.
        post_processing_function (`str`):
            Name of a method of the replacement linear class that is called
            after processing.
    """
    for name, module in model.named_children():
        # Recurse into container modules so nested linear layers are also replaced.
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight = old_module.weight
                model._modules[name].bias = old_module.bias

            if post_processing_function is not None:
                func = getattr(module, post_processing_function, None)
                if func is not None:
                    func(module)
    return model
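

# Minimal usage sketch (illustrative only, not part of the original module): swap every
# torch.nn.Linear in a toy model for a freshly constructed torch.nn.Linear via a lambda
# (a stand-in for a custom replacement class) and carry the original weights over.
if __name__ == "__main__":
    demo = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 4),
    )
    demo = replace_linear(
        demo,
        lambda in_features, out_features, bias: torch.nn.Linear(in_features, out_features, bias),
        skip_modules=[],
        copy_weights=True,
    )
    print(demo)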