2023-08-02 21:53:35 +00:00
|
|
|
"""
|
|
|
|
# https://github.com/enhuiz/pytorch-training-utilities
|
|
|
|
"""
|
|
|
|
|
|
|
|
from .distributed import global_rank, local_rank, global_leader_only
|
|
|
|
|
|
|
|
import gc
|
|
|
|
import logging
|
|
|
|
import pandas as pd
|
2024-06-25 18:41:29 +00:00
|
|
|
import numpy as np
|
2023-08-02 21:53:35 +00:00
|
|
|
import re
|
|
|
|
import torch
|
2024-06-25 18:41:29 +00:00
|
|
|
import random
|
|
|
|
import time
|
2024-08-02 01:12:06 +00:00
|
|
|
import psutil
|
|
|
|
import math
|
2024-08-29 18:27:16 +00:00
|
|
|
import logging
|
|
|
|
|
|
|
|
# Module-level logger used by the helpers in this file for their informational messages.
_logger = logging.getLogger(__name__)
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
from coloredlogs import ColoredFormatter
|
|
|
|
from logging import StreamHandler
|
|
|
|
from pathlib import Path
|
|
|
|
from torch import Tensor, nn
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
from typing import Callable, TypeVar, overload
|
2024-08-02 01:12:06 +00:00
|
|
|
from contextlib import contextmanager
|
2024-10-05 03:18:20 +00:00
|
|
|
|
|
|
|
from time import perf_counter
|
|
|
|
from datetime import datetime
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
# Generic type variable used in the tree_map / to_device annotations.
T = TypeVar("T")
|
|
|
|
|
2024-10-17 22:06:48 +00:00
|
|
|
def prune_missing( source, dest, recurse=True, path=None, parent_is_obj=None, return_missing=True ):
    """Drop the keys of ``dest`` that have no counterpart in ``source``.

    Args:
        source: Reference structure. Either a mapping, or an object whose
            ``__dict__`` is used for the key lookups.
        dest: Dict to filter.
        recurse: When true, dict values of ``dest`` are pruned against the
            matching entry of ``source``.
        path: Key names leading to ``dest``; only used to build the dotted
            names reported as missing.
        parent_is_obj: Whether the top-level ``source`` was an object.
            Inferred on the first call and passed down while recursing.
        return_missing: When true return ``(keep, missing)``, otherwise
            return only ``keep``.

    Returns:
        ``keep`` is the filtered copy of ``dest``; ``missing`` lists the
        dotted paths of every key that was dropped.
    """
    path = list(path) if path else []

    is_obj = hasattr(source, "__dict__")
    if parent_is_obj is None:
        parent_is_obj = is_obj
    haystack = source.__dict__ if is_obj else source

    # an empty dict hanging off an object accepts any key
    accept_all = parent_is_obj and not is_obj and source == {}

    keep = {}
    missing = []
    for k, v in dest.items():
        known = k in haystack
        if not (known or accept_all):
            # dropped keys are reported once; their children are not walked
            missing.append(".".join(path + [k]))
            continue

        keep[k] = v

        # only descend when there is a reference entry to compare against
        if recurse and known and isinstance(v, dict):
            # always ask the nested call for its missing list so the unpacking
            # below works regardless of what the caller asked for
            keep[k], nested_missing = prune_missing(
                haystack[k],
                v,
                path=path + [k],
                parent_is_obj=parent_is_obj,
                return_missing=True,
            )
            missing += nested_missing

    return (keep, missing) if return_missing else keep
|
|
|
|
|
2024-10-05 03:18:20 +00:00
|
|
|
class timer:
    """Context manager that reports how long its ``with`` body took.

    On exit the text ``"<msg> <seconds>s"`` (nine decimal places) is passed to
    ``callback`` when one was supplied, and is always printed to stdout
    prefixed with the current ISO-formatted timestamp.
    """

    def __init__(self, msg="Elapsed time:", callback=None):
        # label placed in front of the elapsed time, plus an optional sink for the report
        self.msg, self.callback = msg, callback

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, type, value, traceback):
        elapsed = perf_counter() - self.start
        report = f'{self.msg} {elapsed:.9f}s'

        if self.callback:
            self.callback(report)

        print(f'[{datetime.now().isoformat()}] {report}')
|
|
|
|
|
2024-01-27 01:41:12 +00:00
|
|
|
def truncate_json( str ):
    """Shorten long decimal literals inside a JSON string to four decimal places.

    Every run of digits, a dot and eight or more digits is re-rendered with
    ``.4f``; the rest of the text is returned untouched.
    """
    return re.sub(r"\d+\.\d{8,}", lambda hit: f"{float(hit.group()):.4f}", str)
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
def do_gc():
    """Run the Python garbage collector, then ask torch to empty its CUDA cache."""
    # order matters: collect first so freed tensors can be returned by the cache flush
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
|
|
|
|
|
|
|
|
def flatten_dict(d):
    """Flatten a nested dict into one level using ``pandas.json_normalize``.

    Returns the first normalized record, or ``{}`` when normalization
    produces no records at all.
    """
    rows = pd.json_normalize(d).to_dict(orient="records")
    if not rows:
        return {}
    return rows[0]
|
|
|
|
|
|
|
|
|
2024-06-25 18:41:29 +00:00
|
|
|
def set_seed(seed=None):
    """Seed Python's ``random``, NumPy and torch with one value and return it.

    A falsy ``seed`` (``None`` or ``0``) is replaced by the current Unix time
    in whole seconds.
    """
    seed = seed or int(time.time())

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    return seed
|
|
|
|
|
2023-08-02 21:53:35 +00:00
|
|
|
def _get_named_modules(module, attrname):
    """Yield ``(name, submodule)`` for each entry of ``module.named_modules()`` that has ``attrname``."""
    yield from (
        (name, child)
        for name, child in module.named_modules()
        if hasattr(child, attrname)
    )
|
|
|
|
|
|
|
|
|
|
|
|
def gather_attribute(module, attrname, delete=True, prefix=True):
    """Collect ``attrname`` from every submodule that has it into one flat dict.

    Args:
        module: Root object whose ``named_modules()`` is searched.
        attrname: Name of the attribute to collect.
        delete: When true the attribute is removed from each submodule after
            it has been read.
        prefix: When true the collected values are nested under ``attrname``
            before flattening, so every key starts with it.

    Returns:
        Flat dict whose keys are built from the submodule names, with runs of
        dots collapsed to a single dot.

    Raises:
        RuntimeError: If deleting the attribute from a submodule fails.
    """
    collected = {}
    for name, child in _get_named_modules(module, attrname):
        collected[name] = getattr(child, attrname)
        if not delete:
            continue
        try:
            delattr(child, attrname)
        except Exception as e:
            raise RuntimeError(f"{name} {child} {attrname}") from e

    nested = {attrname: collected} if prefix else collected

    # remove consecutive dots
    return {re.sub(r"\.+", ".", key): item for key, item in flatten_dict(nested).items()}
|
|
|
|
|
|
|
|
|
|
|
|
def dispatch_attribute(
    module,
    attrname,
    value,
    filter_fn: Callable[[nn.Module], bool] | None = None,
):
    """Set ``attrname`` to ``value`` on every submodule that already has that attribute.

    When ``filter_fn`` is given, only the submodules for which it returns a
    truthy result are updated.
    """
    for _, child in _get_named_modules(module, attrname):
        if filter_fn is not None and not filter_fn(child):
            continue
        setattr(child, attrname, value)
|
|
|
|
|
|
|
|
|
|
|
|
def load_state_dict_non_strict(model, state_dict, logger=None):
    """Load only the entries of ``state_dict`` whose name and shape match ``model``.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        state_dict: Mapping of parameter names to tensors to load from.
        logger: Optional logger; when given, extra and missing parameter names
            are reported through ``logger.warning``.
    """
    current = model.state_dict()
    provided, required = set(state_dict), set(current)

    # a shared name with a different shape is treated as if it was never provided
    mismatched = {
        key for key in provided & required
        if current[key].shape != state_dict[key].shape
    }
    provided -= mismatched
    loadable = {key: state_dict[key] for key in provided & required}

    if logger is not None:
        if extra := provided - required:
            logger.warning(
                f"Extra parameters are found. "
                f"Provided but not required parameters: \n{extra}."
            )
        if absent := required - provided:
            logger.warning(
                f"Some parameters are missing. "
                f"Required but not provided parameters: \n{absent}."
            )

    model.load_state_dict(loadable, strict=False)
|
|
|
|
|
|
|
|
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes formatted records through ``tqdm.write``."""

    def __init__(self, level=logging.INFO):
        super().__init__(level)

    def emit(self, record):
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            # hand failures to logging's own error hook instead of raising from a log call
            self.handleError(record)
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
@global_leader_only
def setup_logging(log_dir: str | Path | None = None, log_level="info"):
    """Configure root logging with a tqdm-aware console handler and an optional log file.

    Args:
        log_dir: Directory that receives ``log.txt`` (created if needed and
            opened in append mode). No file handler is added when ``None``.
        log_level: Level name for the root logger; upper-cased before lookup.
    """
    # console output goes through tqdm so log lines and progress bars do not interleave
    console = TqdmLoggingHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(
        ColoredFormatter(
            f"%(asctime)s - %(name)s - %(levelname)s - GR={global_rank()};LR={local_rank()} - \n%(message)s"
        )
    )
    handlers = [console]

    if log_dir is not None:
        log_path = Path(log_dir) / "log.txt"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        to_file = logging.FileHandler(log_path, mode="a")
        to_file.setLevel(logging.DEBUG)
        handlers.append(to_file)

    logging.basicConfig(
        level=logging.getLevelName(log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
        handlers=handlers,
    )
|
|
|
|
|
|
|
|
@overload
def tree_map(fn: Callable, x: list[T]) -> list[T]:
    ...


@overload
def tree_map(fn: Callable, x: tuple[T, ...]) -> tuple[T, ...]:
    ...


@overload
def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
    ...


@overload
def tree_map(fn: Callable, x: T) -> T:
    ...


def tree_map(fn: Callable, x):
    """Apply ``fn`` to every ``Tensor`` found inside a nested structure.

    Lists, tuples and dicts are rebuilt with the same container kind and
    walked recursively. Tensors are replaced by ``fn(tensor)``. Any other
    value is returned unchanged.

    Args:
        fn: Callable invoked with each tensor.
        x: A tensor, a list/tuple/dict of such values, or anything else.

    Returns:
        A structure shaped like ``x`` with the tensors mapped through ``fn``.
    """
    if isinstance(x, list):
        return [tree_map(fn, xi) for xi in x]
    if isinstance(x, tuple):
        # build a real tuple: a bare generator expression here would hand the
        # caller a one-shot generator instead of a tuple
        return tuple(tree_map(fn, xi) for xi in x)
    if isinstance(x, dict):
        return {k: tree_map(fn, v) for k, v in x.items()}
    if isinstance(x, Tensor):
        return fn(x)
    return x
|
|
|
|
|
|
|
|
|
2024-07-23 00:38:39 +00:00
|
|
|
def to_device(x: T | None, *args, **kwargs) -> T:
|
2024-07-23 00:36:07 +00:00
|
|
|
if x is None:
|
|
|
|
return
|
|
|
|
|
2024-07-23 00:38:39 +00:00
|
|
|
return tree_map(lambda t: t.to(*args, **kwargs), x)
|
2024-07-27 20:36:05 +00:00
|
|
|
|
|
|
|
def coalese( *arg, return_last=True ):
    """Pick one of the arguments that is not ``None``.

    Args:
        *arg: Candidate values.
        return_last: When true return the last non-``None`` value, otherwise
            the first one.

    Returns:
        The selected value, or ``None`` when no argument is non-``None``
        (including when called without arguments).
    """
    candidates = [x for x in arg if x is not None]
    if not candidates:
        # nothing usable was passed; report that instead of raising IndexError
        return None
    return candidates[-1] if return_last else candidates[0]
|
|
|
|
|
|
|
|
# checks if a module name is within a given whitelist/blacklist policy dict
def passes_policy( policy, name ):
    """Decide whether ``name`` is allowed by ``policy``.

    Args:
        policy: ``None`` (everything passes) or a dict that may hold
            ``"exclude"`` and/or ``"include"`` collections of substrings.
        name: Module name to test.

    Returns:
        ``True`` when ``policy`` is ``None``. Otherwise ``False`` if any
        exclude term occurs in ``name``; else ``True`` only if some include
        term occurs in ``name``.
    """
    if policy is None:
        return True

    def mentions(key):
        # True when the policy has this list and one of its terms is a substring of the name
        return key in policy and any(term in name for term in policy[key])

    # exclusion wins over inclusion
    if mentions("exclude"):
        return False
    return mentions("include")
|
|
|
|
|
|
|
|
# handles generically converting to a specific tensor type for the duration of a `with` block (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
    """Yield ``input`` converted to ``to_dtype`` when its dtype equals ``from_dtype``.

    Any other dtype is yielded unchanged. Only the yielded tensor is
    converted; the caller's own tensor is never modified, so nothing needs to
    be converted back on exit. The previous post-yield ``.to(from_dtype)``
    only rebound a local name and has been dropped.

    Args:
        input: Tensor to convert.
        from_dtype: Dtype that triggers the conversion.
        to_dtype: Dtype to convert to.
    """
    yield input.to(to_dtype) if input.dtype == from_dtype else input
|
|
|
|
|
|
|
|
@contextmanager
def autocasts(input, from_dtype, to_dtype):
    """Yield ``input`` converted to ``to_dtype`` when its dtype is one of ``from_dtype``.

    Any other dtype is yielded unchanged. Only the yielded tensor is
    converted; the caller's own tensor is never modified, so nothing needs to
    be converted back on exit. The previous post-yield ``.to(...)`` only
    rebound a local name and has been dropped.

    Args:
        input: Tensor to convert.
        from_dtype: Collection of dtypes that trigger the conversion.
        to_dtype: Dtype to convert to.
    """
    yield input.to(to_dtype) if input.dtype in from_dtype else input
|
|
|
|
|
|
|
|
# handles temporarily upcasting 'index tensors' to int32 so torch accepts them
def autocast_forward( func ):
    """Decorate a ``forward(self, input, ...)`` method so narrow-dtype inputs arrive as ``torch.int32``.

    Inputs whose dtype is int16, int8, uint8, float16 or bfloat16 are
    converted with ``autocasts`` before ``func`` is called; other dtypes are
    passed through unchanged. The wrapper keeps the metadata of ``func``
    (name, docstring, ``__wrapped__``).

    NOTE(review): float16/bfloat16 inputs are converted to int32 as well —
    confirm that is intended for every decorated module.
    """
    from functools import wraps

    @wraps(func)
    def wrapper( self, input, *args, **kwargs ):
        with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k:
            return func( self, k, *args, **kwargs )
    return wrapper
|
|
|
|
|
|
|
|
# handles migrating input tensors to a given device
def auto_align_inputs_forward( module, device=None, name = None ):
    """Wrap ``module.forward`` so tensor arguments are moved to one device first.

    Args:
        module: Object whose ``forward`` is wrapped.
        device: Target device. When ``None`` it is taken from
            ``module.device`` if that attribute exists, otherwise from the
            first parameter, otherwise from the first buffer.
        name: Unused; kept for call compatibility.

    Returns:
        A function with the calling convention of ``module.forward``. If no
        device can be determined the original ``forward`` is returned
        unwrapped.
    """
    func = module.forward

    if device is None:
        if hasattr( module, 'device' ):
            device = module.device
        else:
            # best effort: borrow the device of the first parameter, else of the first buffer
            try:
                anchor = next( module.parameters(), None )
                if anchor is None:
                    anchor = next( module.buffers(), None )
            except Exception:
                anchor = None
            if anchor is None:
                return func
            device = anchor.device

    def _align( value ):
        # only tensors are moved; every other argument is passed through as-is
        return value.to( device=device ) if isinstance( value, torch.Tensor ) else value

    def wrapper( *args, **kwargs ):
        args = [ _align( arg ) for arg in args ]
        kwargs = { k: _align( v ) for k, v in kwargs.items() }

        # special case: "position_embeddings" is a collection of tensors, so the
        # scan above does not reach its members; None is passed through untouched
        position_embeddings = kwargs.get( "position_embeddings" )
        if position_embeddings is not None:
            kwargs["position_embeddings"] = tuple( t.to( device=device ) for t in position_embeddings )

        return func( *args, **kwargs )

    return wrapper
|
|
|
|
|
|
|
|
# kludge, but it works (just realized BitNet has its own replacement routine)
# generalizing this would be nice, but there's no catch-all for the constructor arguments
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
    """Swap every ``target`` submodule of ``model`` for a new ``klass`` instance, in place.

    Each replacement is constructed from the old layer's feature counts and
    bias flag only, then moved to the device and dtype of the model's first
    parameter. No weights are copied here. Submodules that already are
    ``klass`` instances are left alone.

    Args:
        model: Module to modify.
        klass: Class used to build the replacements.
        target: Class of the submodules to replace.
        verbose: When true, log each replacement.

    Returns:
        The same ``model`` object.
    """
    # NOTE(review): `cfg` is not defined or imported in the visible part of this file — confirm it is in scope
    # selects the `input_features` / `output_features` keyword spelling instead of `in_features` / `out_features`
    bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet

    reference = next(model.parameters())
    device, dtype = reference.device, reference.dtype

    # collect the matching names up front, since the loop below swaps modules in place
    matches = [path.split('.') for path, sub in model.named_modules() if isinstance(sub, target)]

    for *parent, k in matches:
        name = '.'.join(parent)
        owner = model.get_submodule(name)
        m = getattr(owner, k)

        if isinstance(m, klass):
            continue

        has_bias = m.bias is not None
        if bnb:
            kwargs = dict(input_features=m.in_features, output_features=m.out_features, bias=has_bias)
        else:
            kwargs = dict(in_features=m.in_features, out_features=m.out_features, bias=has_bias)

        # overwrite
        setattr(owner, k, klass(**kwargs).to(device=device, dtype=dtype))

        if verbose:
            _logger.info(f"Replacing {name}.{k} to: {klass}")

    return model
|
|
|
|
|
|
|
|
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
    """Swap every ``target`` submodule of ``model`` for a new ``klass`` instance, in place.

    Each replacement is constructed from the old layer's embedding settings
    only, then moved to the device and dtype of the model's first parameter.
    No weights are copied here. Submodules that already are ``klass``
    instances are left alone.

    Args:
        model: Module to modify.
        klass: Class used to build the replacements.
        target: Class of the submodules to replace.
        verbose: When true, log each replacement.

    Returns:
        The same ``model`` object.
    """
    reference = next(model.parameters())
    device, dtype = reference.device, reference.dtype

    # settings copied from the old layer into the new constructor
    copied = (
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "max_norm",
        "norm_type",
        "scale_grad_by_freq",
        "sparse",
    )

    # collect the matching names up front, since the loop below swaps modules in place
    matches = [path.split('.') for path, sub in model.named_modules() if isinstance(sub, target)]

    for *parent, k in matches:
        name = '.'.join(parent)
        owner = model.get_submodule(name)
        m = getattr(owner, k)

        if isinstance(m, klass):
            continue

        kwargs = {field: getattr(m, field) for field in copied}

        # overwrite
        setattr(owner, k, klass(**kwargs).to(device=device, dtype=dtype))

        if verbose:
            _logger.info(f"Replacing {name}.{k} to: {klass}")

    return model
|
|
|
|
|
|
|
|
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, mode="math", verbose=False ):
    """Swap every ``target`` submodule of ``model`` for a new ``klass`` instance, in place.

    Each replacement is built as ``klass(config=..., layer_idx=..., mode=mode)``
    from the old module's ``config`` and ``layer_idx`` attributes, then moved
    to the device and dtype of the model's first parameter. No weights are
    copied here. Submodules that already are ``klass`` instances are left
    alone.

    Args:
        model: Module to modify.
        klass: Class used to build the replacements.
        target: Class of the submodules to replace.
        mode: Passed through to ``klass`` unchanged.
        verbose: When true, log each replacement.

    Returns:
        The same ``model`` object.
    """
    reference = next(model.parameters())
    device, dtype = reference.device, reference.dtype

    # collect the matching names up front, since the loop below swaps modules in place
    matches = [path.split('.') for path, sub in model.named_modules() if isinstance(sub, target)]

    for *parent, k in matches:
        name = '.'.join(parent)
        owner = model.get_submodule(name)
        m = getattr(owner, k)

        if isinstance(m, klass):
            continue

        replacement = klass(config=m.config, layer_idx=m.layer_idx, mode=mode)

        # overwrite
        setattr(owner, k, replacement.to(device=device, dtype=dtype))

        if verbose:
            _logger.info(f"Replacing {name}.{k} to: {klass}")

    return model
|
|
|
|
|
|
|
|
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
    """Return ``weight`` with dimension ``dim`` trimmed or padded to length ``target``.

    Args:
        weight: Tensor to resize.
        target: Desired size along ``dim``.
        dim: Dimension to resize. It is honoured for the trimming and the
            padding, not only for the size comparison.
        random: When true the added entries come from ``torch.rand``,
            otherwise they are zeros. Either way they are converted to the
            device and dtype of ``weight``.

    Returns:
        ``weight`` itself when the size already matches, a narrowed view when
        trimming, or a new tensor when expanding.
    """
    current = weight.shape[dim]

    # trim
    if target < current:
        return weight.narrow( dim, 0, target )

    # expand
    if target > current:
        fn = torch.rand if random else torch.zeros
        pad_shape = list( weight.shape )
        pad_shape[dim] = target - current
        # create first, convert after: torch.rand cannot produce every dtype directly
        padding = fn( pad_shape ).to( device=weight.device, dtype=weight.dtype )
        return torch.cat( [ weight, padding ], dim=dim )

    return weight
|
|
|
|
|
2024-08-05 00:56:21 +00:00
|
|
|
def get_devices():
    """List device names: ``"cuda:<i>"`` for each device torch reports, followed by ``"cpu"``."""
    cuda_names = [f"cuda:{index}" for index in range(torch.cuda.device_count())]
    return [*cuda_names, 'cpu']
|
|
|
|
|
2024-08-02 01:12:06 +00:00
|
|
|
# grabs the memory properties of a given device
def get_device_properties( device ):
    """Describe the memory of ``device`` as a dict.

    Device names containing ``"cuda"`` are queried through ``torch.cuda``;
    anything else is reported from ``psutil.virtual_memory()``.

    Returns:
        Dict with ``"name"`` (the ``device`` argument), ``"total"``,
        ``"free"`` and ``"props"`` (the raw properties object that was
        queried).
    """
    if 'cuda' not in device:
        props = psutil.virtual_memory()
        free, total = props.available, props.total
    else:
        props = torch.cuda.get_device_properties(device)
        free, total = torch.cuda.mem_get_info(device)

    return dict(name=device, total=total, free=free, props=props)
|
|
|
|
|
|
|
|
# gets the rough size for a given module's parameters
def get_module_size( module ):
    """Return the combined byte size of ``module``'s parameters and buffers."""
    def nbytes( tensors ):
        # element count times bytes per element, summed over the tensors
        return sum( t.nelement() * t.element_size() for t in tensors )

    return nbytes( module.parameters() ) + nbytes( module.buffers() )
|
|
|
|
|
2024-08-03 01:28:49 +00:00
|
|
|
# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way
|
|
|
|
# it'd be better to just attach to layers itself rather than every single module
|
|
|
|
|
2024-08-02 01:12:06 +00:00
|
|
|
# assigns modules to requested devices for a given policy
def get_model_offload_policy(module, policy=None):
    """Plan which leaf submodules of ``module`` are placed on which device.

    Args:
        module: Root module whose ``named_modules()`` are distributed.
        policy: Optional dict. Recognised keys, each filled with a default
            when absent (the defaults are written into the dict passed in):
            ``"include"`` / ``"exclude"`` (name filters, see
            ``passes_policy``; include defaults to ``["model"]``),
            ``"limits"`` (per-device caps; values below 1 are fractions of
            the total size), ``"assign"`` (per-device include lists) and
            ``"devices"`` (defaults to ``get_devices()``). Any non-dict value
            is treated as an empty policy.

    Returns:
        The device dicts from ``get_device_properties`` extended with a
        ``"modules"`` list of submodule names, restricted to the devices that
        received at least one module.

    Raises:
        AssertionError: If the devices are exhausted before every module has
            been placed.
    """
    # handle any other weird values this is set to
    if not isinstance(policy, dict):
        policy = {}

    # default to only include the core model, and not the other modules (embeddings) in the splitting policy
    policy.setdefault("include", ["model"])
    policy.setdefault("limits", [])
    policy.setdefault("assign", [])
    if "devices" not in policy:
        policy["devices"] = get_devices()  # + cpu to spill the remainder on CPU if overbudget

    # create initial device info
    devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]

    # leaf modules (no children) that pass the policy and have a non-empty name and a non-zero size
    modules = []
    for name, child in module.named_modules():
        if [*child.named_children()] or not passes_policy( policy, name ):
            continue
        size = get_module_size(child)
        if name and size:
            modules.append( (name, size) )

    total_size = sum( size for _, size in modules )

    # set caps if requested in the policy
    for i, cap in enumerate(policy["limits"]):
        # no limit, skip
        if cap <= 0:
            continue
        # fractional caps are scaled to the total size
        budget = math.floor(total_size * cap) if cap < 1 else cap
        # only lower the budget; a cap above the available space is ignored
        if not devices[i]["free"] < budget:
            devices[i]["free"] = budget

    # assign if specific parts of the model are requested for assignment
    if policy["assign"]:
        discarded = []
        for device_index, includes in enumerate( policy["assign"] ):
            device = devices[device_index]

            pending = []
            remaining = device["free"]

            # walk the modules matching this device's include list
            for name, size in modules:
                if not passes_policy( {"include": includes}, name ):
                    continue
                if remaining - size >= 0:
                    # within budget: add to the pending batch
                    pending.append( (name, size) )
                    remaining -= size
                else:
                    # budget exceeded: hand the whole batch back to the general pass
                    discarded += pending
                    pending = []
                    remaining = 0
                    break

            if pending and remaining:
                device["modules"] += [ name for name, _ in pending ]
                device["free"] = remaining

        # only the discarded batches continue to the general pass below
        modules = discarded

    device_index = 0
    module_index = 0
    # assign modules to each device, moving to the next device once one is full
    while module_index < len(modules) and device_index < len(devices):
        device = devices[device_index]
        name, size = modules[module_index]

        if device["free"] - size < 0:
            # does not fit in budget, increase device index
            device_index += 1
            _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
            continue

        # fits within budget
        device["modules"].append( name )
        device["free"] -= size
        module_index += 1

    # to-do: check that all modules are exhausted
    assert module_index >= len(modules)

    # only return devices with modules assigned
    return [ device for device in devices if device["modules"] ]
|
|
|
|
|
|
|
|
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model( model, policy=None ):
    """Distribute ``model``'s submodules over devices and make each forward device-aware.

    The placement comes from ``get_model_offload_policy``. Every placed
    submodule is moved with ``.to()`` and tagged with a ``device`` attribute.
    Afterwards the ``forward`` of every submodule is replaced by the wrapper
    from ``auto_align_inputs_forward``.

    Args:
        model: Module to split; modified in place.
        policy: Passed through to ``get_model_offload_policy``.

    Returns:
        The same ``model`` object.
    """
    placements = get_model_offload_policy(model, policy=policy)

    # move modules to respective devices
    for placement in placements:
        # nothing assigned, skip
        if not placement["modules"]:
            continue

        target = placement["name"]
        for name in placement["modules"]:
            submodule = model.get_submodule(name).to( target )
            # record the device on the module; auto_align_inputs_forward checks for this attribute first
            submodule.device = target

    # wrap modules with forward to ensure all inputs are matched to its device
    for _, submodule in model.named_modules():
        if hasattr( submodule, 'forward' ):
            submodule.forward = auto_align_inputs_forward(submodule)

    # (a disabled debug pass that logged each leaf module's device used to sit here)

    return model
|