""" # https://github.com/enhuiz/pytorch-training-utilities """ from .distributed import global_rank, local_rank, global_leader_only import gc import logging import pandas as pd import numpy as np import re import torch import random import time import psutil import math from coloredlogs import ColoredFormatter from logging import StreamHandler from pathlib import Path from torch import Tensor, nn from tqdm.auto import tqdm from typing import Callable, TypeVar, overload from contextlib import contextmanager T = TypeVar("T") def truncate_json( str ): def fun( match ): return "{:.4f}".format(float(match.group())) return re.sub(r"\d+\.\d{8,}", fun, str) def do_gc(): gc.collect() torch.cuda.empty_cache() def flatten_dict(d): records = pd.json_normalize(d).to_dict(orient="records") return records[0] if records else {} def set_seed(seed=None): if not seed: seed = int(time.time()) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) def _get_named_modules(module, attrname): for name, module in module.named_modules(): if hasattr(module, attrname): yield name, module def gather_attribute(module, attrname, delete=True, prefix=True): ret = {} for name, module in _get_named_modules(module, attrname): ret[name] = getattr(module, attrname) if delete: try: delattr(module, attrname) except Exception as e: raise RuntimeError(f"{name} {module} {attrname}") from e if prefix: ret = {attrname: ret} ret = flatten_dict(ret) # remove consecutive dots ret = {re.sub(r"\.+", ".", k): v for k, v in ret.items()} return ret def dispatch_attribute( module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None, ): for _, module in _get_named_modules(module, attrname): if filter_fn is None or filter_fn(module): setattr(module, attrname, value) def load_state_dict_non_strict(model, state_dict, logger=None): model_state_dict = model.state_dict() provided = set(state_dict) required = set(model_state_dict) agreed = provided & required for k in list(agreed): if 
model_state_dict[k].shape != state_dict[k].shape: agreed.remove(k) provided.remove(k) state_dict = {k: state_dict[k] for k in agreed} if logger is not None and (diff := provided - required): logger.warning( f"Extra parameters are found. " f"Provided but not required parameters: \n{diff}." ) if logger is not None and (diff := required - provided): logger.warning( f"Some parameters are missing. " f"Required but not provided parameters: \n{diff}." ) model.load_state_dict(state_dict, strict=False) class TqdmLoggingHandler(logging.Handler): def __init__(self, level=logging.INFO): super().__init__(level) def emit(self, record): try: msg = self.format(record) tqdm.write(msg) self.flush() except Exception as e: self.handleError(record) @global_leader_only def setup_logging(log_dir: str | Path | None = "log", log_level="info"): handlers = [] #stdout_handler = StreamHandler() stdout_handler = TqdmLoggingHandler() stdout_handler.setLevel(logging.INFO) formatter = ColoredFormatter( f"%(asctime)s - %(name)s - %(levelname)s - GR={global_rank()};LR={local_rank()} - \n%(message)s" ) stdout_handler.setFormatter(formatter) handlers.append(stdout_handler) if log_dir is not None: filename = Path(log_dir) / f"log.txt" filename.parent.mkdir(parents=True, exist_ok=True) file_handler = logging.FileHandler(filename, mode="a") file_handler.setLevel(logging.DEBUG) handlers.append(file_handler) logging.basicConfig( level=logging.getLevelName(log_level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s", handlers=handlers, ) @overload def tree_map(fn: Callable, x: list[T]) -> list[T]: ... @overload def tree_map(fn: Callable, x: tuple[T]) -> tuple[T]: ... @overload def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]: ... @overload def tree_map(fn: Callable, x: T) -> T: ... 
def tree_map(fn: Callable, x): if isinstance(x, list): x = [tree_map(fn, xi) for xi in x] elif isinstance(x, tuple): x = (tree_map(fn, xi) for xi in x) elif isinstance(x, dict): x = {k: tree_map(fn, v) for k, v in x.items()} elif isinstance(x, Tensor): x = fn(x) return x def to_device(x: T | None, *args, **kwargs) -> T: if x is None: return return tree_map(lambda t: t.to(*args, **kwargs), x) def coalese( *arg, return_last=True ): return [ x for x in arg if x is not None ][-1 if return_last else 0] # checks if a module name is within a given whitelist/blacklist policy dict def passes_policy( policy, name ): if policy is None: return True if "exclude" in policy: for term in policy["exclude"]: if term in name: return False if "include" in policy: for term in policy["include"]: if term in name: return True return False # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16) @contextmanager def autocast(input, from_dtype, to_dtype): if input.dtype == from_dtype: input = input.to(to_dtype) yield input input = input.to(from_dtype) else: yield input @contextmanager def autocasts(input, from_dtype, to_dtype): if input.dtype in from_dtype: from_dtype = input.dtype input = input.to(to_dtype) yield input input = input.to(from_dtype) else: yield input # handles temporarily upcasting 'index tensors' so torch will stop bitching def autocast_forward( func ): def wrapper( self, input, *args, **kwargs ): with autocasts( input, [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16], torch.int32 ) as k: return func( self, k, *args, **kwargs ) return wrapper # handles migrating an input tensor to a given devicve def auto_to_forward( module, device=None ): if device is None: device = next(module.parameters()).device func = module.forward def wrapper( self, *args, **kwargs ): # search through args and kwargs for any Tensor arguments args = [*args] for i, arg in enumerate(args): if not isinstance( arg, torch.Tensor ): 
continue args[i] = arg.to( device=device ) for k, v in kwargs.items(): if not isinstance( v, torch.Tensor ): continue kwargs[k] = v.to( device=device ) return func( self, *args, **kwargs ) return wrapper # disgusting kludge, but it works (just realized BitNet has its own replacement routine) # generalizing this would be super sugoi but the there's no catch all for arguments def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ): bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet device = next(model.parameters()).device dtype = next(model.parameters()).dtype modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] for *parent, k in modules: name = '.'.join(parent) m = getattr( model.get_submodule(name), k ) if isinstance(m, klass): continue kwargs = dict( in_features = m.in_features, out_features = m.out_features, bias = m.bias is not None, ) if not bnb else dict( input_features=m.in_features, output_features=m.out_features, bias=m.bias is not None, ) # overwrite setattr( model.get_submodule(name), k, klass( **kwargs ).to(device=device, dtype=dtype) ) if verbose: print(f"Replacing {name}.{k} to", klass) return model def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ): device = next(model.parameters()).device dtype = next(model.parameters()).dtype modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)] for *parent, k in modules: name = '.'.join(parent) m = getattr( model.get_submodule(name), k ) if isinstance(m, klass): continue kwargs = dict( num_embeddings=m.num_embeddings, embedding_dim=m.embedding_dim, padding_idx=m.padding_idx, max_norm=m.max_norm, norm_type=m.norm_type, scale_grad_by_freq=m.scale_grad_by_freq, sparse=m.sparse, ) # overwrite setattr( model.get_submodule(name), k, klass( **kwargs ).to(device=device, dtype=dtype) ) if verbose: print(f"Replacing {name}.{k} to", klass) return model # cannot feasibly do 
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, mode="math", verbose=False ):
    """Replace every ``target`` submodule of ``model`` with ``klass(config=..., layer_idx=..., mode=mode)``.

    Each target module is expected to expose ``.config`` and ``.layer_idx``. The
    replacement is moved to the device/dtype of ``model``'s first parameter and set
    in place. Modules already of type ``klass`` are skipped. Returns ``model``.
    """
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]

    for *parent, k in modules:
        name = '.'.join(parent)

        m = getattr( model.get_submodule(name), k )

        if isinstance(m, klass):
            continue

        kwargs = dict(
            config = m.config,
            layer_idx = m.layer_idx,
            mode = mode,
        )

        # overwrite
        setattr(
            model.get_submodule(name), k,
            klass( **kwargs ).to(device=device, dtype=dtype)
        )

        if verbose:
            print(f"Replacing {name}.{k} to", klass)

    return model


# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
    """Trim or pad ``weight`` along ``dim`` so that ``weight.shape[dim] == target``.

    Trimming keeps the leading ``target`` entries along ``dim`` (a view). Padding
    appends new entries drawn from ``torch.rand`` when ``random`` is true, otherwise
    zeros, cast to ``weight``'s device and dtype. An unchanged size returns ``weight``
    itself.
    """
    current = weight.shape[dim]

    # trim
    if target < current:
        return weight.narrow(dim, 0, target)

    # expand
    if target > current:
        fn = torch.rand if random else torch.zeros
        pad_shape = list(weight.shape)
        pad_shape[dim] = target - current
        # generate then cast, so integer dtypes work the same way as before
        padding = fn( pad_shape ).to(device=weight.device, dtype=weight.dtype)
        return torch.cat( [ weight, padding ], dim=dim )

    return weight


# grabs the memory properties of a given device
def get_device_properties( device ):
    """Report name, total and free memory (bytes) for a device name string.

    Names containing ``'cuda'`` are queried through ``torch.cuda``; anything else
    reports host virtual memory via ``psutil``. ``props`` holds the raw
    backend-specific properties object.
    """
    if 'cuda' in device:
        props = torch.cuda.get_device_properties(device)
        free, total = torch.cuda.mem_get_info(device)
    else:
        props = psutil.virtual_memory()
        free, total = props.available, props.total

    return {"name": device, "total": total, "free": free, "props": props}


# gets the rough size for a given module's parameters
def get_module_size( module ):
    """Return the byte size of ``module``'s parameters plus buffers (recursively)."""
    param_size = sum(p.nelement() * p.element_size() for p in module.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in module.buffers())
    return param_size + buffer_size


# assigns modules to requested devices for a given policy
def get_model_offload_policy(module, policy=None):
    """Greedily assign the leaf submodules of ``module`` to devices in ``named_modules()`` order.

    ``policy`` (optional dict; any non-dict is treated as ``{}``, and the caller's dict
    is never modified) may contain:
      - ``include`` / ``exclude``: substring filters on module names, as understood by
        ``passes_policy``; ``include`` defaults to ``["model"]``.
      - ``limits``: per-device byte caps, matched positionally to ``devices``. Values
        <= 0 mean no cap, values < 1 are fractions of the total assigned size, and a
        cap is ignored when the device has less free memory than it. Limits beyond the
        number of devices are ignored.
      - ``devices``: device names; defaults to every CUDA device followed by ``"cpu"``.

    Returns:
        The device dicts (``name``, ``total``, ``free``, ``props``, ``modules``) that
        received at least one module.

    Raises:
        AssertionError: if the devices run out before every module is assigned.
    """
    # handle any other weird values this is set to; copy so the caller's dict is untouched
    policy = dict(policy) if isinstance(policy, dict) else {}

    # default to only include the core model, and not the other modules (embeddings) in the splitting policy
    policy.setdefault("include", ["model"])
    policy.setdefault("limits", [])
    if "devices" not in policy:
        # + cpu to spill the remainder on CPU if over budget
        policy["devices"] = [f"cuda:{i}" for i in range(torch.cuda.device_count())] + ['cpu']

    # create initial device info
    devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]

    # leaf modules only (no children), filtered by the include/exclude policy
    modules = [
        (name, get_module_size(submodule))
        for name, submodule in module.named_modules()
        if not [*submodule.named_children()] and passes_policy( policy, name )
    ]
    # drop the unnamed root and modules without parameters/buffers
    modules = [ (name, size) for name, size in modules if name and size ]

    total_size = sum(size for _, size in modules)

    # set caps if requested in the policy
    for device, cap in zip(devices, policy["limits"]):
        # no limit, skip
        if cap <= 0:
            continue
        # is fractional, scale to total size
        if cap < 1:
            cap = math.floor(total_size * cap)
        # available space is below cap, don't set
        if device["free"] < cap:
            continue
        # cap to requested size
        device["free"] = cap

    device_index = 0
    module_index = 0
    while module_index < len(modules) and device_index < len(devices):
        device = devices[device_index]
        name, size = modules[module_index]

        # fits within budget
        if device["free"] - size >= 0:
            device["modules"].append( name )
            device["free"] -= size
            module_index += 1
        # does not fit in budget, move on to the next device
        else:
            device_index += 1
            print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")

    assert module_index >= len(modules), "ran out of devices before every module was assigned"

    # only return devices with modules assigned
    return [ device for device in devices if device["modules"] ]
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model( model, policy=None ):
    """Spread ``model``'s submodules over devices as planned by ``get_model_offload_policy``.

    Every module assigned to a device is moved there in place. Then the last module
    on each device gets its ``forward`` wrapped by ``auto_to_forward`` targeting the
    next device in the plan, wrapping from the last device back to the first. The
    wrap is skipped when both are the same device. Returns ``model``.

    NOTE(review): the wrapper from ``auto_to_forward`` declares ``self`` as its first
    parameter but is assigned here as an instance attribute, so the first positional
    argument of the call is not moved -- confirm this is the intended hand-off.
    """
    assignments = get_model_offload_policy(model, policy=policy)
    count = len(assignments)

    for index, assignment in enumerate(assignments):
        names = assignment["modules"]
        # nothing assigned, skip
        if not names:
            continue

        # nn.Module.to moves in place, so the returned module is not needed
        for name in names:
            model.get_submodule(name).to( assignment["name"] )

        # the device after this one, wrapping around to the first after the last
        next_device = assignments[(index + 1) % count]["name"]

        # same device, don't bother wrapping
        if assignment["name"] == next_device:
            continue

        tail = model.get_submodule( names[-1] )
        tail.forward = auto_to_forward(tail, next_device)

    return model