""" # https://github.com/enhuiz/pytorch-training-utilities """ from .distributed import global_rank, local_rank, global_leader_only import gc import logging import pandas as pd import numpy as np import re import torch import random import time import psutil import math import logging import hashlib _logger = logging.getLogger(__name__) from coloredlogs import ColoredFormatter from logging import StreamHandler from pathlib import Path from torch import Tensor, nn from tqdm.auto import tqdm from typing import Callable, TypeVar, overload from contextlib import contextmanager from time import perf_counter from datetime import datetime T = TypeVar("T") # removes prefix from key in a dict # useful for mapping args like ar_temperature => temperature def convert_kwargs( kwargs, prefix ): copied = {} | kwargs for key, value in copied.items(): if not key.startswith( prefix ): continue kwargs.pop(key) kwargs[key[len(prefix):]] = value return kwargs # hashes values or a list of values def md5_hash( x ): if isinstance( x, list ): return md5_hash(":".join([ md5_hash( _ ) for _ in x ])) return hashlib.md5(str(x).encode("utf-8")).hexdigest() # removes entries from a dict if that key is missing from the source def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ): is_obj = hasattr( source, "__dict__" ) if parent_is_obj is None: parent_is_obj = is_obj haystack = source.__dict__ if is_obj else source keep = {} missing = [] for k, v in dest.items(): if k in haystack or (parent_is_obj and not is_obj and source == {}): keep[k] = dest[k] else: missing.append(".".join(path + [k])) if recurse and isinstance( v, dict ): keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing ) missing += m return (keep, missing) if return_missing else keep def clamp(n, lo, hi): return max(lo, min(n, hi)) class timer: def __init__(self, msg="Elapsed time:", callback=None): self.msg = msg self.callback = callback def __enter__(self): self.start = perf_counter() return self def __exit__(self, type, value, traceback): msg = f'{self.msg} {(perf_counter() - self.start):.9f}s' if self.callback: self.callback(msg) print(f'[{datetime.now().isoformat()}] {msg}') def truncate_json( x ): if isinstance( x, bytes ): return truncate_json( x.decode('utf-8') ).encode() def fun( match ): return "{:.4f}".format(float(match.group())) return re.sub(r"\d+\.\d{8,}", fun, x) def do_gc(): gc.collect() torch.cuda.empty_cache() def flatten_dict(d): records = pd.json_normalize(d).to_dict(orient="records") return records[0] if records else {} def set_seed(seed=None): if not seed: seed = int(time.time()) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) return seed def _get_named_modules(module, attrname): for name, module in module.named_modules(): if hasattr(module, attrname): yield name, module def gather_attribute(module, attrname, delete=True, prefix=True): ret = {} for name, module in _get_named_modules(module, attrname): ret[name] = getattr(module, attrname) if delete: try: delattr(module, attrname) except Exception as e: raise RuntimeError(f"{name} {module} {attrname}") from e if prefix: ret = {attrname: ret} ret = flatten_dict(ret) # remove consecutive dots ret = {re.sub(r"\.+", ".", k): v for k, v in ret.items()} return ret def dispatch_attribute( module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None, ): for _, module in _get_named_modules(module, attrname): if filter_fn is None or filter_fn(module): setattr(module, attrname, value) def load_state_dict_non_strict(model, state_dict, logger=None): model_state_dict = model.state_dict() provided = set(state_dict) required = set(model_state_dict) agreed = provided & required for k in list(agreed): if model_state_dict[k].shape != state_dict[k].shape: agreed.remove(k) provided.remove(k) state_dict = {k: state_dict[k] for k in agreed} if logger is not None and (diff := provided - required): logger.warning( f"Extra parameters are found. " f"Provided but not required parameters: \n{diff}." ) if logger is not None and (diff := required - provided): logger.warning( f"Some parameters are missing. " f"Required but not provided parameters: \n{diff}." ) model.load_state_dict(state_dict, strict=False) class TqdmLoggingHandler(logging.Handler): def __init__(self, level=logging.INFO): super().__init__(level) def emit(self, record): try: msg = self.format(record) tqdm.write(msg) self.flush() except Exception as e: self.handleError(record) @global_leader_only def setup_logging(log_dir: str | Path | None = None, log_level="info"): handlers = [] #stdout_handler = StreamHandler() stdout_handler = TqdmLoggingHandler() stdout_handler.setLevel(logging.INFO) formatter = ColoredFormatter( f"%(asctime)s - %(name)s - %(levelname)s - GR={global_rank()};LR={local_rank()} - \n%(message)s" ) stdout_handler.setFormatter(formatter) handlers.append(stdout_handler) if log_dir is not None: filename = Path(log_dir) / f"log.txt" filename.parent.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler(filename, mode="a") file_handler.setLevel(logging.DEBUG) handlers.append(file_handler) logging.basicConfig( level=logging.getLevelName(log_level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s", handlers=handlers, ) @overload def tree_map(fn: Callable, x: list[T]) -> list[T]: ... @overload def tree_map(fn: Callable, x: tuple[T]) -> tuple[T]: ... @overload def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]: ... @overload def tree_map(fn: Callable, x: T) -> T: ... def tree_map(fn: Callable, x): if isinstance(x, list): x = [tree_map(fn, xi) for xi in x] elif isinstance(x, tuple): x = (tree_map(fn, xi) for xi in x) elif isinstance(x, dict): x = {k: tree_map(fn, v) for k, v in x.items()} elif isinstance(x, Tensor): x = fn(x) return x def to_device(x: T | None, *args, **kwargs) -> T: if x is None: return return tree_map(lambda t: t.to(*args, **kwargs), x) def coalese( *arg, return_last=True ): return [ x for x in arg if x is not None ][-1 if return_last else 0] # checks if a module name is within a given whitelist/blacklist policy dict def passes_policy( policy, name ): if policy is None: return True if "exclude" in policy: for term in policy["exclude"]: if term in name: return False if "include" in policy: for term in policy["include"]: if term in name: return True return False # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16) @contextmanager def autocast(input, from_dtype, to_dtype): if input.dtype == from_dtype: input = input.to(to_dtype) yield input input = input.to(from_dtype) else: yield input @contextmanager def autocasts(input, from_dtype, to_dtype): if input.dtype in from_dtype: from_dtype = input.dtype input = input.to(to_dtype) yield input input = input.to(from_dtype) else: yield input # handles temporarily upcasting 'index tensors' so torch will stop bitching def autocast_forward( func ): def wrapper( self, input, *args, **kwargs ): with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k: return func( self, k, *args, **kwargs ) return wrapper # handles migrating an input tensor to a given devicve def auto_align_inputs_forward( module, device=None, name = None ): func = module.forward if device is None: if hasattr( module, 'device' ): device = module.device else: try: device = next(module.parameters() if [*module.parameters()] else module.buffers()).device except Exception as e: return func def wrapper( *args, **kwargs ): args = [*args] # search through args and kwargs for any Tensor arguments for i, arg in enumerate(args): if not isinstance( arg, torch.Tensor ): continue args[i] = arg.to( device=device ) for k, v in kwargs.items(): if not isinstance( v, torch.Tensor ): continue kwargs[k] = v.to( device=device ) # disgusting patch if "position_embeddings" in kwargs: kwargs["position_embeddings"] = tuple([ t.to(device=device) for t in kwargs["position_embeddings"] ]) return func( *args, **kwargs ) return wrapper # disgusting kludge, but it works (just realized BitNet has its own replacement routine) # generalizing this would be super sugoi but the there's no catch all for arguments def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ): bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet device = next(model.parameters()).device dtype = next(model.parameters()).dtype modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] for *parent, k in modules: name = '.'.join(parent) m = getattr( model.get_submodule(name), k ) if isinstance(m, klass): continue kwargs = dict( in_features = m.in_features, out_features = m.out_features, bias = m.bias is not None, ) if not bnb else dict( input_features=m.in_features, output_features=m.out_features, bias=m.bias is not None, ) # overwrite setattr( model.get_submodule(name), k, klass( **kwargs ).to(device=device, dtype=dtype) ) if verbose: _logger.info(f"Replacing {name}.{k} to: {klass}") return model def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ): device = next(model.parameters()).device dtype = next(model.parameters()).dtype modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] for *parent, k in modules: name = '.'.join(parent) m = getattr( model.get_submodule(name), k ) if isinstance(m, klass): continue kwargs = dict( num_embeddings=m.num_embeddings, embedding_dim=m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, scale_grad_by_freq=m.scale_grad_by_freq, sparse=m.sparse, ) # overwrite setattr( model.get_submodule(name), k, klass( **kwargs ).to(device=device, dtype=dtype) ) if verbose: _logger.info(f"Replacing {name}.{k} to: {klass}") return model # cannot feasibly do default arguments here sad def replace_attention( model, klass, target, mode="math", verbose=False ): device = next(model.parameters()).device dtype = next(model.parameters()).dtype modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] for *parent, k in modules: name = '.'.join(parent) m = getattr( model.get_submodule(name), k ) if isinstance(m, klass): continue kwargs = dict( config = m.config, layer_idx = m.layer_idx, mode = mode, ) # overwrite setattr( model.get_submodule(name), k, klass( **kwargs ).to(device=device, dtype=dtype) ) if verbose: _logger.info(f"Replacing {name}.{k} to: {klass}") return model # trim/expand a tensor (for example, in a state dict) def resize_weight( weight, target, dim=0, random=True ): # trim if target < weight.shape[dim]: return weight[:target] # expand if target > weight.shape[dim]: fn = torch.rand if random else torch.zeros return torch.stack( [ x for x in weight ] + [ fn( weight[0].shape ).to(device=weight[0].device, dtype=weight[0].dtype) for _ in range( target - weight.shape[dim] ) ] ) return weight def get_devices(): return [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())] + ['cpu'] # grabs the memory properties of a given device def get_device_properties( device ): if 'cuda' in device: props = torch.cuda.get_device_properties(device) free, total = torch.cuda.mem_get_info(device) else: props = psutil.virtual_memory() free, total = props.available, props.total return {"name": device, "total": total, "free": free, "props": props} # gets the rough size for a given module's parameters def get_module_size( module ): param_size = sum([p.nelement() * p.element_size() for p in module.parameters()]) buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()]) return param_size + buffer_size # to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way # it'd be better to just attach to layers itself rather than every single module # assigns modules to requested devices for a given policy def get_model_offload_policy(module, policy=None): # handle any other weird values this is set to if not isinstance(policy, dict): policy = {} # default to only include the core model, and not the other modules (embeddings) in the splitting policy if "include" not in policy: policy["include"] = ["model"] if "limits" not in policy: policy["limits"] = [] if "assign" not in policy: policy["assign"] = [] if "devices" not in policy: policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget # create initial device info devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ] modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ] # filter modules = [ (name, size) for name, size in modules if name and size ] total_size = sum([size for name, size in modules]) # set caps if requested in the policy for i, cap in enumerate(policy["limits"]): # no limit, skip if cap <= 0: continue # is fractional, scale to total size if cap < 1: cap = math.floor(total_size * cap) # available space is below cap, don't set if devices[i]["free"] < cap: continue # cap to requested size devices[i]["free"] = cap # assign if specific parts of the model are requested for assignment if policy["assign"]: discarded = [] # yuck, there has to be a better way for device_index, includes in enumerate( policy["assign"] ): device = devices[device_index] buffered_modules = [] buffered_size = device["free"] # iterate through list of modules to compare against includes for name, size in modules: # doesn't pass policy if not passes_policy( {"include": includes}, name ): continue # check if within budget if buffered_size - size >= 0: # add to buffer buffered_modules.append( (name, size) ) buffered_size -= size # budget exceeded, flush buffer else: discarded += buffered_modules buffered_modules = [] buffered_size = 0 break if buffered_modules and buffered_size: device["modules"] += [ name for name, size in buffered_modules ] device["free"] = buffered_size modules = discarded device_index = 0 module_index = 0 # assign modules to each device while module_index < len(modules) and device_index < len(devices): device = devices[device_index] name, size = modules[module_index] # fits within budget if device["free"] - size >= 0: device["modules"].append( name ) device["free"] -= size module_index += 1 # does not fit in budget, increase device index else: device_index += 1 _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB") # to-do: check that all modules are exhausted assert module_index >= len(modules) # only return devices with modules assigned return [ device for device in devices if device["modules"] ] # handles naively splitting a model's layers across multiple devices # this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU def offload_model( model, policy=None ): policy = get_model_offload_policy(model, policy=policy) # move modules to respective devices for i, device in enumerate( policy ): # nothing assigned, skip if not device["modules"]: continue for name in device["modules"]: module = model.get_submodule(name) module = module.to( device["name"] ) module.device = device['name'] # wrap modules with forward to ensure all inputs are matched to its device for name, module in model.named_modules(): if not hasattr( module, 'forward' ): continue module.forward = auto_align_inputs_forward(module) """ # Validate that the layers are all in the right spot for name, module in model.named_modules(): if not not [*module.named_children()]: continue try: _logger.info( name, next(module.parameters()).device ) except Exception as e: _logger.info( name, "?" ) pass """ return model