vall-e/vall_e/utils/utils.py

608 lines
17 KiB
Python
Executable File

"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
from .distributed import global_rank, local_rank, global_leader_only
import gc
import logging
import pandas as pd
import numpy as np
import re
import torch
import random
import time
import psutil
import math
import logging
import hashlib
_logger = logging.getLogger(__name__)
from coloredlogs import ColoredFormatter
from logging import StreamHandler
from pathlib import Path
from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
from contextlib import contextmanager
from time import perf_counter
from datetime import datetime
T = TypeVar("T")
# strips a prefix from the keys of a dict
# useful for mapping args like ar_temperature => temperature
def convert_kwargs(kwargs, prefix):
    """Rename, in place, every key of ``kwargs`` that starts with ``prefix``.

    Matching keys are removed and re-inserted without the prefix, so they
    move to the end of the dict. Keys without the prefix are left untouched.

    Args:
        kwargs: dict with string keys; it is mutated.
        prefix: string to strip from the front of matching keys.

    Returns:
        The same ``kwargs`` object that was passed in.
    """
    # walk a snapshot so the live dict can be edited during the loop
    snapshot = dict(kwargs)
    for name, item in snapshot.items():
        if name.startswith(prefix):
            kwargs.pop(name)
            kwargs[name[len(prefix):]] = item
    return kwargs
# hashes a value, or a list of values
def md5_hash(x):
    """Return the hexadecimal MD5 digest of ``x``.

    A list is hashed element by element (recursively); the element digests
    are joined with ``":"`` and that joined string is hashed again. Any other
    value is converted with ``str`` and encoded as UTF-8 before hashing.
    """
    if not isinstance(x, list):
        encoded = str(x).encode("utf-8")
        return hashlib.md5(encoded).hexdigest()

    digests = [md5_hash(item) for item in x]
    return md5_hash(":".join(digests))
# removes entries from a dict if that key is missing from the source
def prune_missing(source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True):
    """Filter ``dest`` down to the keys that also exist in ``source``.

    Args:
        source: reference structure; either a mapping or an object whose
            ``__dict__`` is used as the mapping.
        dest: dict to filter. It is not modified.
        recurse: when true, dict values of ``dest`` are filtered against the
            matching entry of ``source``.
        path: key path of ``dest`` inside the outermost call, used to build
            dotted names in the missing list. The default list is never
            mutated (only ``path + [k]`` copies are made).
        parent_is_obj: whether the outermost ``source`` was an object;
            filled in automatically on the first call.
        return_missing: when true return ``(kept, missing)``, otherwise
            return only ``kept``.

    Returns:
        ``kept`` is a new dict; ``missing`` is a list of dotted key paths
        that were dropped.
    """
    is_obj = hasattr(source, "__dict__")
    if parent_is_obj is None:
        parent_is_obj = is_obj
    haystack = source.__dict__ if is_obj else source

    keep = {}
    missing = []
    for k, v in dest.items():
        known = k in haystack
        # an empty dict hanging off an object accepts arbitrary keys
        if known or (parent_is_obj and not is_obj and source == {}):
            keep[k] = v
        else:
            missing.append(".".join(path + [k]))

        # Only descend when the source actually has a counterpart to compare
        # against; previously a missing key raised KeyError here.
        if recurse and known and isinstance(v, dict):
            # Always ask the nested call for the (kept, missing) pair so the
            # unpacking below is valid even when the caller passed
            # return_missing=False.
            keep[k], nested_missing = prune_missing(
                haystack[k],
                v,
                path=path + [k],
                parent_is_obj=parent_is_obj,
                return_missing=True,
            )
            missing += nested_missing

    return (keep, missing) if return_missing else keep
def clamp(n, lo, hi):
    """Limit ``n`` to the range ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so ``lo``
    wins if the bounds are passed in the wrong order.
    """
    capped = min(n, hi)
    return max(lo, capped)
class timer:
    """Context manager that reports the wall-clock duration of its body.

    On exit a message ``"<msg> <seconds>s"`` (nine decimals) is built, handed
    to ``callback`` when one was given, and always printed with an ISO
    timestamp prefix. Exceptions from the body are not suppressed.
    """

    def __init__(self, msg="Elapsed time:", callback=None):
        # msg: label placed in front of the measured duration
        # callback: optional callable receiving the formatted message
        self.callback = callback
        self.msg = msg

    def __enter__(self):
        # remember when the block started
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        msg = f'{self.msg} {(perf_counter() - self.start):.9f}s'

        notify = self.callback
        if notify:
            notify(msg)

        print(f'[{datetime.now().isoformat()}] {msg}')
def truncate_json(x):
    """Shorten long decimal literals in JSON text to four decimal places.

    Every run of digits, a dot and eight or more digits is re-formatted with
    ``"{:.4f}"``. ``bytes`` input is decoded as UTF-8, processed, and encoded
    again; ``str`` input yields ``str``.
    """
    if isinstance(x, bytes):
        text = x.decode('utf-8')
        return truncate_json(text).encode()

    # numbers with fewer than eight decimals do not match and stay as-is
    return re.sub(
        r"\d+\.\d{8,}",
        lambda found: "{:.4f}".format(float(found.group())),
        x,
    )
def do_gc():
    """Run Python's garbage collector, then call ``torch.cuda.empty_cache()``."""
    # collect first so tensors that just became unreachable are freed
    # before the CUDA cache is emptied
    gc.collect()
    torch.cuda.empty_cache()
def flatten_dict(d):
    """Flatten a nested dict into a single level via ``pd.json_normalize``.

    Returns the first normalized record, or an empty dict when normalization
    yields no records.
    """
    rows = pd.json_normalize(d).to_dict(orient="records")
    if not rows:
        return {}
    return rows[0]
def set_seed(seed=None):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: integer seed. Any falsy value (``None`` and also ``0``) is
            replaced by the current Unix time in whole seconds.

    Returns:
        The seed that was actually applied.
    """
    seed = seed or int(time.time())

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    return seed
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):
yield name, module
def gather_attribute(module, attrname, delete=True, prefix=True):
    """Collect the value of ``attrname`` from every submodule that has it.

    Args:
        module: root module whose ``named_modules()`` are scanned.
        attrname: attribute to read from each submodule.
        delete: when true, the attribute is removed from the submodule after
            it has been read.
        prefix: when true, the result is nested under ``attrname`` before
            flattening, so every key starts with that name.

    Returns:
        A flat dict (built with ``flatten_dict``) mapping key paths to the
        gathered values, with runs of consecutive dots collapsed to one.

    Raises:
        RuntimeError: if deleting the attribute fails.
    """
    ret = {}
    for name, module in _get_named_modules(module, attrname):
        ret[name] = getattr(module, attrname)
        if not delete:
            continue
        try:
            delattr(module, attrname)
        except Exception as e:
            raise RuntimeError(f"{name} {module} {attrname}") from e

    flat = flatten_dict({attrname: ret} if prefix else ret)

    # remove consecutive dots
    return {re.sub(r"\.+", ".", key): item for key, item in flat.items()}
def dispatch_attribute(
    module,
    attrname,
    value,
    filter_fn: Callable[[nn.Module], bool] | None = None,
):
    """Overwrite ``attrname`` with ``value`` on submodules that already have it.

    Args:
        module: root module whose submodules are scanned.
        attrname: attribute to overwrite; submodules without it are skipped.
        value: new value to assign.
        filter_fn: optional predicate; when given, only submodules for which
            it returns a truthy result are updated.
    """
    for _, candidate in _get_named_modules(module, attrname):
        if filter_fn is not None and not filter_fn(candidate):
            continue
        setattr(candidate, attrname, value)
def load_state_dict_non_strict(model, state_dict, logger=None):
    """Load the entries of ``state_dict`` that fit ``model``; skip the rest.

    An entry is loaded only if its key exists in the model's own state dict
    and the tensor shapes are equal. Entries with a shape mismatch are
    treated as if they had not been provided at all.

    Args:
        model: module exposing ``state_dict()`` and ``load_state_dict()``.
        state_dict: mapping of parameter names to tensors.
        logger: optional logger; when given, extra and missing keys are
            reported through ``logger.warning``.
    """
    model_state_dict = model.state_dict()
    required = set(model_state_dict)

    # keys present on both sides whose shapes disagree
    mismatched = {
        k
        for k in state_dict
        if k in required and model_state_dict[k].shape != state_dict[k].shape
    }
    provided = set(state_dict) - mismatched
    agreed = provided & required

    loadable = {k: state_dict[k] for k in agreed}

    if logger is not None:
        if diff := provided - required:
            logger.warning(
                f"Extra parameters are found. "
                f"Provided but not required parameters: \n{diff}."
            )
        if diff := required - provided:
            logger.warning(
                f"Some parameters are missing. "
                f"Required but not provided parameters: \n{diff}."
            )

    model.load_state_dict(loadable, strict=False)
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that emits records through ``tqdm.write``."""

    def __init__(self, level=logging.INFO):
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and write it with ``tqdm.write``.

        Any exception raised while doing so is passed to ``handleError``
        instead of propagating.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
@global_leader_only
def setup_logging(log_dir: str | Path | None = None, log_level="info"):
    """Configure the root logger with a console handler and optional log file.

    The console handler is a ``TqdmLoggingHandler`` fixed at INFO level with
    a colored format that includes the global and local rank. When
    ``log_dir`` is given, ``<log_dir>/log.txt`` is opened in append mode
    (parent directories are created) with a DEBUG-level file handler.

    Args:
        log_dir: directory for ``log.txt``; ``None`` disables file logging.
        log_level: level name for the root logger, case-insensitive.

    NOTE(review): the ``global_leader_only`` decorator is defined in
    ``.distributed``; presumably it restricts this call to the leader
    process - confirm there.
    """
    console = TqdmLoggingHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(
        ColoredFormatter(
            f"%(asctime)s - %(name)s - %(levelname)s - GR={global_rank()};LR={local_rank()} - \n%(message)s"
        )
    )
    handlers = [console]

    if log_dir is not None:
        filename = Path(log_dir) / f"log.txt"
        filename.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(filename, mode="a")
        file_handler.setLevel(logging.DEBUG)
        handlers.append(file_handler)

    # handlers without their own formatter (the file handler) pick up
    # the format given here
    root_level = logging.getLevelName(log_level.upper())
    logging.basicConfig(
        level=root_level,
        format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
        handlers=handlers,
    )
@overload
def tree_map(fn: Callable, x: "list[T]") -> "list[T]":
    ...


@overload
def tree_map(fn: Callable, x: "tuple[T]") -> "tuple[T]":
    ...


@overload
def tree_map(fn: Callable, x: "dict[str, T]") -> "dict[str, T]":
    ...


@overload
def tree_map(fn: Callable, x: "T") -> "T":
    ...


def tree_map(fn: Callable, x):
    """Apply ``fn`` to every ``Tensor`` inside a nested container.

    Lists, tuples and dicts are rebuilt recursively with the same container
    kind (dict keys are kept). Tensors are replaced by ``fn(tensor)``. Any
    other value is returned unchanged.

    Args:
        fn: callable taking a tensor and returning its replacement.
        x: a tensor, a list/tuple/dict of such structures, or any other
            object.

    Returns:
        A structure shaped like ``x`` with every tensor mapped through ``fn``.
    """
    if isinstance(x, list):
        return [tree_map(fn, xi) for xi in x]
    if isinstance(x, tuple):
        # Materialize the result: a bare generator expression here would
        # hand callers a one-shot generator instead of a tuple.
        return tuple(tree_map(fn, xi) for xi in x)
    if isinstance(x, dict):
        return {k: tree_map(fn, v) for k, v in x.items()}
    if isinstance(x, Tensor):
        return fn(x)
    return x
def to_device(x: T | None, *args, **kwargs) -> T:
    """Call ``.to(*args, **kwargs)`` on every tensor found inside ``x``.

    ``x`` may be a tensor or a nested list/tuple/dict handled by
    ``tree_map``. Returns ``None`` when ``x`` is ``None``.
    """
    if x is None:
        return None

    def move(tensor):
        return tensor.to(*args, **kwargs)

    return tree_map(move, x)
def coalese(*arg, return_last=True):
    """Pick one of the arguments that is not ``None``.

    Args:
        *arg: candidate values.
        return_last: when true return the last non-``None`` value, otherwise
            the first one.

    Raises:
        IndexError: if every argument is ``None`` (or none were given).
    """
    present = [candidate for candidate in arg if candidate is not None]
    position = -1 if return_last else 0
    return present[position]
# checks if a module name is within a given whitelist/blacklist policy dict
def passes_policy(policy, name):
    """Decide whether ``name`` is accepted by an include/exclude policy.

    Args:
        policy: ``None`` (everything passes) or a dict that may hold the
            lists ``"exclude"`` and ``"include"`` of substrings.
        name: the name to test.

    Returns:
        ``True`` if ``policy`` is ``None``. Otherwise ``False`` when any
        exclude term occurs in ``name``; else ``True`` only when some include
        term occurs in ``name``. A policy without a matching include term
        rejects the name.
    """
    if policy is None:
        return True

    # exclusions take precedence over inclusions
    if "exclude" in policy and any(term in name for term in policy["exclude"]):
        return False

    return "include" in policy and any(term in name for term in policy["include"])
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
    """Yield ``input`` converted to ``to_dtype`` if its dtype is ``from_dtype``.

    Tensors of any other dtype are yielded unchanged. After the ``with``
    body finishes, the converted tensor is cast back to ``from_dtype``, but
    that result is only bound locally and never reaches the caller.
    """
    if input.dtype != from_dtype:
        yield input
        return

    converted = input.to(to_dtype)
    yield converted
    # mirrors the original round trip; the value is discarded
    converted = converted.to(from_dtype)
@contextmanager
def autocasts(input, from_dtype, to_dtype):
    """Like ``autocast`` but ``from_dtype`` is a collection of dtypes.

    If ``input.dtype`` is contained in ``from_dtype`` the tensor is converted
    to ``to_dtype`` and yielded; otherwise it is yielded unchanged. The cast
    back to the original dtype after the ``with`` body is only bound locally
    and never reaches the caller.
    """
    original_dtype = input.dtype
    if original_dtype not in from_dtype:
        yield input
        return

    converted = input.to(to_dtype)
    yield converted
    # mirrors the original round trip; the value is discarded
    converted = converted.to(original_dtype)
# handles temporarily upcasting 'index tensors' so torch stops complaining about their dtype
def autocast_forward(func):
    """Wrap a ``forward(self, input, ...)`` method so ``input`` is cast first.

    When ``input`` has dtype int16, int8, uint8, float16 or bfloat16 it is
    converted to int32 (via ``autocasts``) before ``func`` is called; other
    dtypes pass through untouched. Remaining arguments are forwarded as-is
    and the result of ``func`` is returned.
    """
    castable = [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16]

    def wrapper(self, input, *args, **kwargs):
        with autocasts(input, castable, torch.int32) as k:
            return func(self, k, *args, **kwargs)

    return wrapper
# handles migrating input tensors to a given device
def auto_align_inputs_forward(module, device=None, name=None):
    """Build a replacement for ``module.forward`` that moves tensors first.

    Args:
        module: module whose ``forward`` is wrapped.
        device: target device. When ``None`` it is taken from
            ``module.device`` if that attribute exists, otherwise from the
            module's first parameter, or its first buffer when it has no
            parameters.
        name: unused.

    Returns:
        A wrapper that moves every positional and keyword ``torch.Tensor``
        argument to ``device`` and then calls the original ``forward``. If no
        device can be determined, the original ``forward`` is returned as-is.
    """
    func = module.forward

    if device is None:
        if hasattr(module, 'device'):
            device = module.device
        else:
            try:
                tensors = module.parameters() if [*module.parameters()] else module.buffers()
                device = next(tensors).device
            except Exception:
                # nothing to infer a device from: leave forward untouched
                return func

    def _move(value):
        # only tensors are relocated; everything else passes through
        return value.to(device=device) if isinstance(value, torch.Tensor) else value

    def wrapper(*args, **kwargs):
        # search through args and kwargs for any Tensor arguments
        args = [_move(arg) for arg in args]
        for k in list(kwargs):
            kwargs[k] = _move(kwargs[k])

        # special case: "position_embeddings" is a sequence of tensors
        # rather than a tensor, so its members are moved one by one
        if "position_embeddings" in kwargs:
            kwargs["position_embeddings"] = tuple(
                t.to(device=device) for t in kwargs["position_embeddings"]
            )

        return func(*args, **kwargs)

    return wrapper
# kludge, but it works (BitNet has its own replacement routine)
# generalizing this would be nice, but there's no catch-all for constructor arguments
def replace_linear(model, klass, target=torch.nn.Linear, verbose=False):
    """Swap every ``target`` submodule of ``model`` for a new ``klass`` instance.

    The replacement is constructed from the old layer's ``in_features``,
    ``out_features`` and whether it has a bias, then moved to the device and
    dtype of the model's first parameter. No weights are copied by this
    function. Submodules that already are instances of ``klass`` are skipped.

    Args:
        model: module to modify in place.
        klass: class of the replacement layer.
        target: class of the layers to replace.
        verbose: log each replacement through ``_logger``.

    Returns:
        ``model`` itself.

    NOTE(review): ``cfg`` is read here but is not defined or imported in the
    visible part of this module - confirm where it comes from.
    """
    # with bitsandbytes the constructor uses different keyword names
    bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet

    reference = next(model.parameters())
    device, dtype = reference.device, reference.dtype

    # collect names first so the tree is not mutated while being walked
    qualified_names = [n for n, m in model.named_modules() if isinstance(m, target)]

    for qualified in qualified_names:
        name, _, k = qualified.rpartition('.')
        owner = model.get_submodule(name)
        m = getattr(owner, k)

        if isinstance(m, klass):
            continue

        has_bias = m.bias is not None
        if bnb:
            kwargs = dict(input_features=m.in_features, output_features=m.out_features, bias=has_bias)
        else:
            kwargs = dict(in_features=m.in_features, out_features=m.out_features, bias=has_bias)

        # overwrite
        setattr(owner, k, klass(**kwargs).to(device=device, dtype=dtype))

        if verbose:
            _logger.info(f"Replacing {name}.{k} to: {klass}")

    return model
def replace_embedding(model, klass, target=torch.nn.Embedding, verbose=False):
    """Swap every ``target`` submodule of ``model`` for a new ``klass`` instance.

    The replacement is constructed with the old layer's ``num_embeddings``,
    ``embedding_dim``, ``padding_idx``, ``max_norm``, ``norm_type``,
    ``scale_grad_by_freq`` and ``sparse`` settings, then moved to the device
    and dtype of the model's first parameter. No weights are copied by this
    function. Submodules that already are instances of ``klass`` are skipped.

    Args:
        model: module to modify in place.
        klass: class of the replacement layer.
        target: class of the layers to replace.
        verbose: log each replacement through ``_logger``.

    Returns:
        ``model`` itself.
    """
    reference = next(model.parameters())
    device, dtype = reference.device, reference.dtype

    # collect names first so the tree is not mutated while being walked
    qualified_names = [n for n, m in model.named_modules() if isinstance(m, target)]

    for qualified in qualified_names:
        name, _, k = qualified.rpartition('.')
        owner = model.get_submodule(name)
        m = getattr(owner, k)

        if isinstance(m, klass):
            continue

        replacement = klass(
            num_embeddings=m.num_embeddings,
            embedding_dim=m.embedding_dim,
            padding_idx=m.padding_idx,
            max_norm=m.max_norm,
            norm_type=m.norm_type,
            scale_grad_by_freq=m.scale_grad_by_freq,
            sparse=m.sparse,
        )

        # overwrite
        setattr(owner, k, replacement.to(device=device, dtype=dtype))

        if verbose:
            _logger.info(f"Replacing {name}.{k} to: {klass}")

    return model
# no sensible default exists for `target` here
def replace_attention(model, klass, target, mode="math", verbose=False):
    """Swap every ``target`` submodule of ``model`` for a new ``klass`` instance.

    The replacement is constructed as
    ``klass(config=old.config, layer_idx=old.layer_idx, mode=mode)`` and moved
    to the device and dtype of the model's first parameter. No weights are
    copied by this function. Submodules that already are instances of
    ``klass`` are skipped.

    Args:
        model: module to modify in place.
        klass: class of the replacement attention module.
        target: class of the modules to replace.
        mode: forwarded to ``klass`` as its ``mode`` argument.
        verbose: log each replacement through ``_logger``.

    Returns:
        ``model`` itself.
    """
    reference = next(model.parameters())
    device, dtype = reference.device, reference.dtype

    # collect names first so the tree is not mutated while being walked
    qualified_names = [n for n, m in model.named_modules() if isinstance(m, target)]

    for qualified in qualified_names:
        name, _, k = qualified.rpartition('.')
        owner = model.get_submodule(name)
        m = getattr(owner, k)

        if isinstance(m, klass):
            continue

        replacement = klass(config=m.config, layer_idx=m.layer_idx, mode=mode)

        # overwrite
        setattr(owner, k, replacement.to(device=device, dtype=dtype))

        if verbose:
            _logger.info(f"Replacing {name}.{k} to: {klass}")

    return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight(weight, target, dim=0, random=True):
    """Return ``weight`` resized to ``target`` entries along ``dim``.

    Args:
        weight: tensor to resize.
        target: desired size of dimension ``dim``.
        dim: dimension to trim or extend. Previously it was only used for
            the size comparison while the resize itself always acted on
            dimension 0; it is now honoured throughout.
        random: fill newly added entries with ``torch.rand`` values when
            true, with zeros otherwise.

    Returns:
        ``weight`` itself if the size already matches; a leading slice when
        trimming; a new tensor with the extra entries appended (same device
        and dtype as ``weight``) when expanding.
    """
    current = weight.shape[dim]

    # trim
    if target < current:
        index = [slice(None)] * weight.dim()
        index[dim] = slice(None, target)
        return weight[tuple(index)]

    # expand
    if target > current:
        fn = torch.rand if random else torch.zeros
        pad_shape = list(weight.shape)
        pad_shape[dim] = target - current
        # built from the full shape, so it also works for an empty tensor
        padding = fn(pad_shape).to(device=weight.device, dtype=weight.dtype)
        return torch.cat([weight, padding], dim=dim)

    return weight
def get_devices():
    """List device names: ``"cuda:<i>"`` for each index counted by
    ``torch.cuda.device_count()``, followed by ``"cpu"``."""
    names = [f'{"cuda"}:{i}' for i in range(torch.cuda.device_count())]
    names.append('cpu')
    return names
# grabs the memory properties of a given device
def get_device_properties(device):
    """Describe the memory available on ``device``.

    Args:
        device: device name; one containing ``"cuda"`` is queried through
            ``torch.cuda``, anything else through ``psutil.virtual_memory()``.

    Returns:
        A dict with ``"name"`` (the given device), ``"total"`` and ``"free"``
        memory as reported by the backend, and ``"props"`` (the raw
        properties object).
    """
    if 'cuda' not in device:
        # host memory
        props = psutil.virtual_memory()
        free, total = props.available, props.total
    else:
        props = torch.cuda.get_device_properties(device)
        free, total = torch.cuda.mem_get_info(device)

    return {"name": device, "total": total, "free": free, "props": props}
# gets the rough size for a given module's parameters
def get_module_size(module):
    """Return the combined byte size of ``module``'s parameters and buffers,
    computed as element count times element size for each tensor."""

    def total_bytes(tensors):
        return sum(t.nelement() * t.element_size() for t in tensors)

    return total_bytes(module.parameters()) + total_bytes(module.buffers())
# to-do: rewrite all of this; the current approach is clumsy
# it'd be better to just attach to layers itself rather than every single module
# assigns modules to requested devices for a given policy
#
# Purpose: build a per-device plan describing which leaf submodules of `module` go on which device.
# Parameters:
#   module - root module whose named_modules() are inspected
#   policy - optional dict with the keys "include", "limits", "assign" and "devices";
#            any non-dict value is replaced by an empty dict
# Returns: a list of device dicts (the get_device_properties() fields plus "modules",
#          a list of submodule names), restricted to devices that received at least one module.
# Side effect: missing keys are written into the `policy` dict that was passed in.
# NOTE(review): indentation was lost in this copy of the file; the nesting described in the
# comments below is inferred from the statements themselves - confirm against the upstream source.
def get_model_offload_policy(module, policy=None):
# handle any other weird values this is set to
if not isinstance(policy, dict):
policy = {}
# default to only include the core model, and not the other modules (embeddings) in the splitting policy
if "include" not in policy:
policy["include"] = ["model"]
if "limits" not in policy:
policy["limits"] = []
if "assign" not in policy:
policy["assign"] = []
if "devices" not in policy:
policy["devices"] = get_devices() # + cpu to spill the remainder on CPU if overbudget
# create initial device info
# each entry is the device's property dict merged with an empty "modules" list
devices = [ get_device_properties(device) | {"modules": []} for device in policy["devices"] ]
# (name, size) pairs for submodules that have no children and pass the include/exclude policy
modules = [ (name, get_module_size(module)) for name, module in module.named_modules() if not [*module.named_children()] and passes_policy( policy, name ) ]
# filter
# drops entries with an empty name or a size of zero
modules = [ (name, size) for name, size in modules if name and size ]
total_size = sum([size for name, size in modules])
# set caps if requested in the policy
# policy["limits"][i] applies to devices[i]
for i, cap in enumerate(policy["limits"]):
# no limit, skip
if cap <= 0:
continue
# is fractional, scale to total size
if cap < 1:
cap = math.floor(total_size * cap)
# available space is below cap, don't set
if devices[i]["free"] < cap:
continue
# cap to requested size
devices[i]["free"] = cap
# assign if specific parts of the model are requested for assignment
# policy["assign"][i] is the list of include terms for devices[i]
if policy["assign"]:
discarded = []
# yuck, there has to be a better way
for device_index, includes in enumerate( policy["assign"] ):
device = devices[device_index]
buffered_modules = []
buffered_size = device["free"]
# iterate through list of modules to compare against includes
for name, size in modules:
# doesn't pass policy
if not passes_policy( {"include": includes}, name ):
continue
# check if within budget
if buffered_size - size >= 0:
# add to buffer
buffered_modules.append( (name, size) )
buffered_size -= size
# budget exceeded, flush buffer
else:
# everything buffered so far for this device is moved to `discarded`
discarded += buffered_modules
buffered_modules = []
buffered_size = 0
break
# commit the buffer only if something was buffered and budget remains
if buffered_modules and buffered_size:
device["modules"] += [ name for name, size in buffered_modules ]
device["free"] = buffered_size
# only the discarded modules are handed to the generic assignment loop below
modules = discarded
device_index = 0
module_index = 0
# assign modules to each device
# greedy fill: stay on the current device until a module no longer fits, then advance
while module_index < len(modules) and device_index < len(devices):
device = devices[device_index]
name, size = modules[module_index]
# fits within budget
if device["free"] - size >= 0:
device["modules"].append( name )
device["free"] -= size
module_index += 1
# does not fit in budget, increase device index
else:
device_index += 1
_logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
# to-do: check that all modules are exhausted
# fails when the devices ran out before every module was placed
assert module_index >= len(modules)
# only return devices with modules assigned
return [ device for device in devices if device["modules"] ]
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model(model, policy=None):
    """Distribute ``model``'s submodules over devices according to ``policy``.

    The plan comes from ``get_model_offload_policy``. Each listed submodule
    is moved to its device and tagged with a ``device`` attribute holding the
    device name. Afterwards the ``forward`` of every module in the model is
    replaced by the wrapper from ``auto_align_inputs_forward`` so tensor
    inputs are moved to the module's device before the call.

    Args:
        model: module to split; modified in place.
        policy: optional policy dict passed to ``get_model_offload_policy``.

    Returns:
        ``model`` itself.
    """
    plan = get_model_offload_policy(model, policy=policy)

    # move modules to respective devices
    for entry in plan:
        target = entry["name"]
        # an empty module list simply results in no iterations
        for name in entry["modules"]:
            submodule = model.get_submodule(name).to(target)
            submodule.device = target

    # wrap modules with forward to ensure all inputs are matched to its device
    for _, submodule in model.named_modules():
        if hasattr(submodule, 'forward'):
            submodule.forward = auto_align_inputs_forward(submodule)

    # Disabled debugging aid, kept for reference:
    #
    # # Validate that the layers are all in the right spot
    # for name, module in model.named_modules():
    #     if not not [*module.named_children()]:
    #         continue
    #     try:
    #         _logger.info( name, next(module.parameters()).device )
    #     except Exception as e:
    #         _logger.info( name, "?" )
    #         pass

    return model