vall-e/vall_e/utils/utils.py

490 lines
14 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
from .distributed import global_rank, local_rank, global_leader_only
import gc
import logging
import pandas as pd
import numpy as np
2023-08-02 21:53:35 +00:00
import re
import torch
import random
import time
import psutil
import math
2023-08-02 21:53:35 +00:00
from coloredlogs import ColoredFormatter
from logging import StreamHandler
from pathlib import Path
from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
from contextlib import contextmanager
2023-08-02 21:53:35 +00:00
# Generic type variable used by the tree_map overloads and the to_device annotations.
T = TypeVar("T")
def truncate_json( str ):
    """Round every decimal number with 8+ fractional digits in *str* to 4 places.

    Presumably used to shorten JSON text before printing. Shorter decimals are
    left untouched; a leading minus sign is preserved because only the digits
    are matched. The parameter name shadows the builtin ``str`` but is kept for
    keyword-argument compatibility.
    """
    return re.sub(
        r"\d+\.\d{8,}",
        lambda found: "{:.4f}".format(float(found.group())),
        str,
    )
2023-08-02 21:53:35 +00:00
def do_gc():
    """Free memory eagerly.

    Runs a full Python garbage-collection pass first, so unreachable tensors are
    released, then asks torch to drop its cached CUDA allocator blocks.
    """
    gc.collect()
    torch.cuda.empty_cache()
def flatten_dict(d):
    """Flatten nested dict *d* into one level with dotted keys (via ``pandas.json_normalize``).

    Returns the single normalized record, or an empty dict when normalization
    produces no rows.
    """
    rows = pd.json_normalize(d).to_dict(orient="records")
    if not rows:
        return {}
    return rows[0]
def set_seed(seed=None):
    """Seed Python's ``random``, NumPy's global RNG and torch's RNG.

    Args:
        seed: seed value. Any falsy value (``None`` *or* ``0``) is replaced by
            the current Unix time in whole seconds.

    Returns:
        The seed that was actually applied. This lets callers log a
        time-derived seed and reproduce the run later.
    """
    if not seed:
        # NOTE: 0 is treated like None here; kept for backward compatibility.
        seed = int(time.time())
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed
2023-08-02 21:53:35 +00:00
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):
yield name, module
def gather_attribute(module, attrname, delete=True, prefix=True):
    """Collect ``attrname`` from every submodule of *module* that defines it.

    Args:
        module: root module to search.
        attrname: attribute to read from each matching submodule.
        delete: when True, remove the attribute from each submodule after reading it.
        prefix: when True, nest the results under ``attrname`` before flattening,
            so the flattened keys start with ``attrname``.

    Returns:
        A flat dict (built by ``flatten_dict``) keyed by dotted submodule names;
        runs of consecutive dots in the keys are collapsed to a single dot.

    Raises:
        RuntimeError: if deleting the attribute fails; the original error is chained.
    """
    collected = {}
    for sub_name, sub in _get_named_modules(module, attrname):
        collected[sub_name] = getattr(sub, attrname)
        if not delete:
            continue
        try:
            delattr(sub, attrname)
        except Exception as e:
            raise RuntimeError(f"{sub_name} {sub} {attrname}") from e

    if prefix:
        collected = {attrname: collected}
    flat = flatten_dict(collected)
    # remove consecutive dots
    return {re.sub(r"\.+", ".", key): value for key, value in flat.items()}
def dispatch_attribute(
    module,
    attrname,
    value,
    filter_fn: Callable[[nn.Module], bool] | None = None,
):
    """Assign ``attrname = value`` on every submodule of *module* that already has ``attrname``.

    When *filter_fn* is given, only submodules for which it returns a truthy
    value are updated.
    """
    targets = (sub for _, sub in _get_named_modules(module, attrname))
    for sub in targets:
        if filter_fn is not None and not filter_fn(sub):
            continue
        setattr(sub, attrname, value)
def load_state_dict_non_strict(model, state_dict, logger=None):
    """Load the entries of *state_dict* that match *model* by name and shape.

    Keys whose tensor shape differs from the model's are skipped. They are
    reported as missing, and additionally as shape mismatches. Everything that
    agrees is loaded with ``strict=False``.

    Args:
        model: target module.
        state_dict: mapping of parameter/buffer names to tensors.
        logger: optional logger. When given, extra, shape-mismatched and missing
            keys are reported through ``logger.warning``.
    """
    model_state_dict = model.state_dict()
    provided = set(state_dict)
    required = set(model_state_dict)

    # keys present on both sides whose shapes are incompatible
    mismatched = {
        k for k in provided & required
        if model_state_dict[k].shape != state_dict[k].shape
    }
    agreed = (provided & required) - mismatched
    # treat mismatched keys as not provided so they also surface in the "missing" report
    provided -= mismatched

    if logger is not None and (diff := provided - required):
        logger.warning(
            "Extra parameters are found. "
            f"Provided but not required parameters: \n{diff}."
        )
    if logger is not None and mismatched:
        shapes = {
            k: (tuple(state_dict[k].shape), tuple(model_state_dict[k].shape))
            for k in sorted(mismatched)
        }
        logger.warning(
            "Some parameters have mismatched shapes and were skipped. "
            f"(provided shape, required shape): \n{shapes}."
        )
    if logger is not None and (diff := required - provided):
        logger.warning(
            "Some parameters are missing. "
            f"Required but not provided parameters: \n{diff}."
        )

    model.load_state_dict({k: state_dict[k] for k in agreed}, strict=False)
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints formatted records with ``tqdm.write``.

    Presumably used instead of a plain ``StreamHandler`` so that log lines do
    not break active tqdm progress bars.
    """

    def __init__(self, level=logging.INFO):
        super().__init__(level)

    def emit(self, record):
        """Format *record* and write it via tqdm.

        Any failure is delegated to ``Handler.handleError``.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
2023-08-02 21:53:35 +00:00
@global_leader_only
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
    """Configure root logging with a colored console handler and an optional log file.

    Args:
        log_dir: directory that receives ``log.txt``. The file is opened in
            append mode and the directory is created if missing. ``None``
            disables file logging.
        log_level: root level name, case-insensitive (e.g. "info", "debug").

    The console handler is fixed at INFO. Its format tags each record with the
    global and local rank. The file handler accepts DEBUG and has no formatter
    of its own here, so presumably basicConfig's format applies to it.
    """
    #stdout_handler = StreamHandler()
    console = TqdmLoggingHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(ColoredFormatter(
        f"%(asctime)s - %(name)s - %(levelname)s - GR={global_rank()};LR={local_rank()} - \n%(message)s"
    ))
    handlers = [console]

    if log_dir is not None:
        log_path = Path(log_dir) / "log.txt"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        to_file = logging.FileHandler(log_path, mode="a")
        to_file.setLevel(logging.DEBUG)
        handlers.append(to_file)

    logging.basicConfig(
        level=logging.getLevelName(log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
        handlers=handlers,
    )
@overload
def tree_map(fn: Callable, x: list[T]) -> list[T]:
    ...
@overload
def tree_map(fn: Callable, x: tuple[T, ...]) -> tuple[T, ...]:
    ...
@overload
def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
    ...
@overload
def tree_map(fn: Callable, x: T) -> T:
    ...
def tree_map(fn: Callable, x):
    """Apply *fn* to every Tensor leaf of a nested list/tuple/dict structure.

    Containers are rebuilt with the same kind:
    - lists become lists;
    - tuples become tuples, and namedtuples keep their type;
    - dicts become dicts with the same keys.

    Tensors are replaced by ``fn(tensor)``. Any other leaf is returned unchanged.
    """
    if isinstance(x, list):
        return [tree_map(fn, xi) for xi in x]
    if isinstance(x, tuple):
        # must materialize: a bare generator expression here would return a
        # one-shot generator instead of a tuple
        mapped = [tree_map(fn, xi) for xi in x]
        # namedtuples take their fields positionally; plain tuples take an iterable
        return type(x)(*mapped) if hasattr(type(x), "_fields") else tuple(mapped)
    if isinstance(x, dict):
        return {k: tree_map(fn, v) for k, v in x.items()}
    if isinstance(x, Tensor):
        return fn(x)
    return x
2024-07-23 00:38:39 +00:00
def to_device(x: T | None, *args, **kwargs) -> T:
    """Move every Tensor inside *x* with ``Tensor.to(*args, **kwargs)``.

    Nested lists/tuples/dicts are traversed with ``tree_map``; non-Tensor leaves
    are left as-is. ``None`` is passed straight through as ``None``.
    """
    return None if x is None else tree_map(lambda tensor: tensor.to(*args, **kwargs), x)
def coalese( *arg, return_last=True ):
    """Return the last (or first) argument that is not ``None``.

    Args:
        *arg: candidate values. Falsy values such as ``0`` or ``""`` still count.
        return_last: when True (default), the last non-None value wins;
            otherwise the first non-None value is returned.

    Returns:
        The selected value. Returns ``None`` when every argument is ``None`` or
        no arguments were given; previously this raised ``IndexError``.
    """
    candidates = reversed(arg) if return_last else arg
    return next((x for x in candidates if x is not None), None)
# checks if a module name is within a given whitelist/blacklist policy dict
def passes_policy( policy, name ):
    """Decide whether module *name* is allowed by an include/exclude *policy*.

    Args:
        policy: ``None`` (allow everything) or a dict that may hold "exclude"
            and/or "include" lists of substrings.
        name: module name to test; terms match as plain substrings.

    Returns:
        False if any "exclude" term occurs in *name*. Otherwise True iff some
        "include" term occurs in *name*.

    NOTE(review): a policy without an "include" key therefore rejects every
    name, even when it only lists exclusions. Confirm this is intended.
    """
    if policy is None:
        return True
    excluded_terms = policy["exclude"] if "exclude" in policy else ()
    if any(term in name for term in excluded_terms):
        return False
    included_terms = policy["include"] if "include" in policy else ()
    return any(term in name for term in included_terms)
# yields a copy of the tensor converted to another dtype when it has the given source dtype; the caller's tensor is left untouched (implemented solely for bfloat16)
@contextmanager
def autocast(input, from_dtype, to_dtype):
    """Yield *input* converted to *to_dtype* if its dtype is exactly *from_dtype*.

    Args:
        input: tensor to (possibly) convert. It is never modified in place.
        from_dtype: the single dtype that triggers a conversion.
        to_dtype: dtype to convert to.

    Yields:
        ``input.to(to_dtype)`` when ``input.dtype == from_dtype``; otherwise
        *input* itself.

    No conversion back is attempted after the ``with`` block. A context manager
    cannot rebind the caller's variables, so converting the yielded copy back
    would only allocate a throw-away tensor.
    """
    yield input.to(to_dtype) if input.dtype == from_dtype else input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
    """Yield *input* converted to *to_dtype* if its dtype is one of *from_dtype*.

    Args:
        input: tensor to (possibly) convert. It is never modified in place.
        from_dtype: collection of dtypes that trigger a conversion.
        to_dtype: dtype to convert to.

    Yields:
        ``input.to(to_dtype)`` when ``input.dtype in from_dtype``; otherwise
        *input* itself.

    No conversion back is attempted after the ``with`` block. The previous
    back-conversion only rebound a local variable, so it allocated a tensor
    nobody could observe.
    """
    yield input.to(to_dtype) if input.dtype in from_dtype else input
# temporarily upcasts low-precision 'index tensors' to int32 so torch stops complaining about their dtype
def autocast_forward( func ):
    """Decorate a ``forward(self, input, *args, **kwargs)`` method to upcast *input* to int32.

    If *input* is int16/int8/uint8/float16/bfloat16, the wrapped method receives
    ``input.to(torch.int32)`` instead. Other arguments pass through untouched.
    Presumably meant for index-like inputs, since float values are truncated
    by the cast.

    ``functools.wraps`` keeps the wrapped method's name, docstring and
    ``__wrapped__`` reference, which helps introspection and debugging.
    """
    from functools import wraps

    # dtypes upcast to int32; built once per decorated function instead of on every call
    upcast_dtypes = [torch.int16, torch.int8, torch.uint8, torch.float16, torch.bfloat16]

    @wraps(func)
    def wrapper( self, input, *args, **kwargs ):
        with autocasts( input, upcast_dtypes, torch.int32 ) as k:
            return func( self, k, *args, **kwargs )
    return wrapper
# handles migrating a module's tensor arguments to a given device
def auto_to_forward( module, device=None ):
    """Build a replacement ``forward`` for *module* that moves Tensor arguments to *device* first.

    Args:
        module: module whose (already bound) ``forward`` is captured.
        device: target device. Defaults to the device of *module*'s first parameter.

    Returns:
        ``wrapper(self, *args, **kwargs)``. It moves every Tensor found in
        ``args`` and ``kwargs`` to *device*, then calls the captured
        ``module.forward(self, *args, **kwargs)``.

    NOTE(review): ``func`` is a bound method, and ``offload_model`` assigns the
    wrapper as an instance attribute (``last_module.forward = ...``), where a
    plain function is not bound. If ``Module.__call__`` invokes
    ``self.forward(x, ...)`` as usual, then ``x`` lands in the ``self`` slot.
    It is passed through un-moved and becomes the first argument of the bound
    ``forward``; only the remaining tensors are moved. A call with no positional
    arguments would presumably raise ``TypeError``. Confirm the intended
    behaviour before changing it, because ``offload_model`` may depend on it.
    """
    if device is None:
        device = next(module.parameters()).device

    func = module.forward

    def wrapper( self, *args, **kwargs ):
        # search through args and kwargs for any Tensor arguments
        args = [*args]
        for i, arg in enumerate(args):
            if not isinstance( arg, torch.Tensor ):
                continue
            args[i] = arg.to( device=device )

        # rebinding values of existing keys while iterating is fine (the dict size does not change)
        for k, v in kwargs.items():
            if not isinstance( v, torch.Tensor ):
                continue
            kwargs[k] = v.to( device=device )

        return func( self, *args, **kwargs )
    return wrapper
# crude kludge, but it works (just realized BitNet has its own replacement routine)
# generalizing this would be nice, but there's no catch-all for the constructor arguments
def replace_linear( model, klass, target=torch.nn.Linear, verbose=False, bnb=None ):
    """Swap every *target* submodule of *model* for a freshly constructed *klass*.

    Args:
        model: module tree to modify in place.
        klass: replacement class. Submodules already of this class are skipped.
        target: module type to replace (default ``torch.nn.Linear``).
        verbose: print each replacement.
        bnb: whether *klass* takes bitsandbytes-style constructor keywords
            (``input_features``/``output_features``) instead of torch's
            (``in_features``/``out_features``). ``None`` (default) derives it
            from ``cfg.optimizations`` as before.

    Returns:
        *model*, modified in place.

    NOTE: replacements are newly initialized; the old layer's weights are NOT
    copied. The new layers are moved to the device/dtype of *model*'s first
    parameter.
    """
    if bnb is None:
        # `cfg` is not among this file's visible imports, so the bare reference was a
        # NameError. Import it lazily so callers passing `bnb` never touch the config.
        # Assumes the package's `config` module exposes `cfg` -- TODO confirm.
        from ..config import cfg
        bnb = cfg.optimizations.bitsandbytes and cfg.optimizations.linear and not cfg.optimizations.bitnet

    first_param = next(model.parameters())
    device, dtype = first_param.device, first_param.dtype

    modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
    for *parent, k in modules:
        name = '.'.join(parent)
        parent_module = model.get_submodule(name)
        m = getattr( parent_module, k )
        if isinstance(m, klass):
            continue

        has_bias = m.bias is not None
        if bnb:
            kwargs = dict(input_features=m.in_features, output_features=m.out_features, bias=has_bias)
        else:
            kwargs = dict(in_features=m.in_features, out_features=m.out_features, bias=has_bias)

        # overwrite
        setattr( parent_module, k, klass( **kwargs ).to(device=device, dtype=dtype) )

        if verbose:
            print(f"Replacing {name}.{k} to", klass)

    return model
def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
    """Swap every *target* embedding submodule of *model* for a freshly constructed *klass*.

    The new module receives the old one's constructor settings: size,
    padding_idx, max_norm, norm_type, scale_grad_by_freq and sparse. Weights are
    not copied. It is placed on the device/dtype of *model*'s first parameter.
    Submodules already of type *klass* are skipped.

    Returns:
        *model*, modified in place.
    """
    reference = next(model.parameters())
    placement = dict(device=reference.device, dtype=reference.dtype)

    paths = [qualified.split('.') for qualified, sub in model.named_modules() if isinstance(sub, target)]
    for *owner_path, attr in paths:
        owner_name = '.'.join(owner_path)
        owner = model.get_submodule(owner_name)
        old = getattr(owner, attr)
        if isinstance(old, klass):
            continue

        replacement = klass(
            num_embeddings=old.num_embeddings,
            embedding_dim=old.embedding_dim,
            padding_idx=old.padding_idx,
            max_norm=old.max_norm,
            norm_type=old.norm_type,
            scale_grad_by_freq=old.scale_grad_by_freq,
            sparse=old.sparse,
        ).to(**placement)
        setattr(owner, attr, replacement)

        if verbose:
            print(f"Replacing {owner_name}.{attr} to", klass)

    return model
# cannot feasibly provide default arguments here, sadly
def replace_attention( model, klass, target, mode="math", verbose=False ):
    """Swap every *target* attention submodule for ``klass(config=..., layer_idx=..., mode=mode)``.

    Only the old module's ``config`` and ``layer_idx`` attributes are forwarded.
    The new module is freshly constructed, so no weights are copied. It is
    placed on the device/dtype of *model*'s first parameter. Submodules already
    of type *klass* are skipped.

    Returns:
        *model*, modified in place.
    """
    anchor = next(model.parameters())
    device, dtype = anchor.device, anchor.dtype

    matches = [path.split('.') for path, sub in model.named_modules() if isinstance(sub, target)]
    for *owner_path, attr in matches:
        owner_name = '.'.join(owner_path)
        old = getattr(model.get_submodule(owner_name), attr)
        if isinstance(old, klass):
            continue

        fresh = klass(config=old.config, layer_idx=old.layer_idx, mode=mode)
        setattr(model.get_submodule(owner_name), attr, fresh.to(device=device, dtype=dtype))

        if verbose:
            print(f"Replacing {owner_name}.{attr} to", klass)

    return model
# trim/expand a tensor (for example, in a state dict)
def resize_weight( weight, target, dim=0, random=True ):
    """Trim or pad *weight* along *dim* so that ``weight.shape[dim] == target``.

    Args:
        weight: tensor to resize (for example, an entry of a state dict).
        target: desired size along *dim*.
        dim: dimension to resize; negative values count from the end.
            Previously this argument was ignored and dim 0 was always used.
        random: pad with ``torch.rand`` values (uniform on [0, 1)) when True,
            zeros otherwise. The name shadows the ``random`` module inside
            this function only.

    Returns:
        Depending on the requested size:
        - *weight* itself when it already has the requested size;
        - a view of the leading slice when trimming;
        - a new tensor with the padding appended after the existing entries
          when expanding. The padding is created with *weight*'s device and dtype.
    """
    current = weight.shape[dim]

    if target < current:
        # narrow slices along `dim`; plain `weight[:target]` would always slice dim 0
        return weight.narrow(dim, 0, target)

    if target > current:
        pad_shape = list(weight.shape)
        pad_shape[dim] = target - current
        fill = torch.rand if random else torch.zeros
        padding = fill(pad_shape).to(device=weight.device, dtype=weight.dtype)
        # cat (unlike stacking weight[0]-shaped rows) also works when `weight` is empty along `dim`
        return torch.cat([weight, padding], dim=dim)

    return weight
# grabs the memory properties of a given device
def get_device_properties( device ):
    """Return memory information for *device*.

    Args:
        device: device identifier such as ``"cuda:0"`` or ``"cpu"``. A
            ``torch.device`` is also accepted: it is stringified for the CUDA
            check, whereas the previous ``'cuda' in device`` test failed on it.

    Returns:
        dict with the following keys:
        - "name": *device* exactly as given;
        - "total", "free": memory in bytes. CUDA figures come from
          ``torch.cuda.mem_get_info``; any non-CUDA device reports host RAM
          (total and available) from ``psutil.virtual_memory``;
        - "props": the raw properties object.
    """
    if 'cuda' in str(device):
        props = torch.cuda.get_device_properties(device)
        free, total = torch.cuda.mem_get_info(device)
    else:
        # any non-CUDA device is budgeted against host RAM
        props = psutil.virtual_memory()
        free, total = props.available, props.total
    return {"name": device, "total": total, "free": free, "props": props}
# gets the rough size for a given module's parameters
def get_module_size( module ):
    """Return the number of bytes held by *module*'s parameters plus its buffers.

    Each tensor counts as ``nelement() * element_size()``.
    """
    total = 0
    for tensors in (module.parameters(), module.buffers()):
        total += sum(t.nelement() * t.element_size() for t in tensors)
    return total
# assigns modules to requested devices for a given policy
def get_model_offload_policy(module, policy=None):
    """Greedily assign *module*'s leaf submodules to devices, in order.

    Args:
        module: model whose leaf submodules (those without children) are distributed.
        policy: optional dict. Non-dict values are treated as an empty policy.
            The caller's dict is never modified. Recognised keys:
            - "include" / "exclude": name filters (see ``passes_policy``).
              "include" defaults to ``["model"]``.
            - "limits": per-device caps, aligned with "devices".
              - A value <= 0 means no cap.
              - A value < 1 is a fraction of the total size of the selected modules.
              - Other values are byte counts.
              - A cap is only applied when the device has at least that much free memory.
              - Caps beyond the number of devices are ignored.
            - "devices": device strings. Defaults to every visible CUDA device,
              followed by "cpu" to absorb whatever does not fit.

    Returns:
        The ``get_device_properties`` dicts extended with "modules" (the names
        assigned to each device). Devices with nothing assigned are omitted.

    Raises:
        AssertionError: if the selected modules do not fit on the devices.
    """
    # handle any other weird values this is set to; copy so the defaults filled
    # in below are not written back into the caller's (e.g. config) dict
    policy = dict(policy) if isinstance(policy, dict) else {}

    # default to only include the core model, and not the other modules (embeddings) in the splitting policy
    if "include" not in policy:
        policy["include"] = ["model"]
    if "limits" not in policy:
        policy["limits"] = []
    if "devices" not in policy:
        # + cpu to spill the remainder on CPU if overbudget
        policy["devices"] = [f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["cpu"]

    # create initial device info
    devices = [get_device_properties(device) | {"modules": []} for device in policy["devices"]]

    # leaf modules only (no children) that pass the include/exclude filters
    modules = [
        (name, get_module_size(submodule))
        for name, submodule in module.named_modules()
        if not [*submodule.named_children()] and passes_policy(policy, name)
    ]
    # drop the unnamed root and modules holding zero bytes of parameters/buffers
    modules = [(name, size) for name, size in modules if name and size]
    total_size = sum(size for _, size in modules)

    # set caps if requested in the policy (zip ignores caps beyond the device list)
    for device, cap in zip(devices, policy["limits"]):
        # no limit, skip
        if cap <= 0:
            continue
        # is fractional, scale to total size
        if cap < 1:
            cap = math.floor(total_size * cap)
        # available space is below cap, don't set
        if device["free"] < cap:
            continue
        # cap to requested size
        device["free"] = cap

    # fill devices in order; a module that does not fit advances to the next device
    device_index = 0
    module_index = 0
    while module_index < len(modules) and device_index < len(devices):
        device = devices[device_index]
        name, size = modules[module_index]
        if device["free"] - size >= 0:
            device["modules"].append(name)
            device["free"] -= size
            module_index += 1
        else:
            device_index += 1
            print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")

    # explicit raise instead of `assert` so the check survives `python -O`
    if module_index < len(modules):
        raise AssertionError(
            f"Ran out of devices: {len(modules) - module_index} module(s) could not be placed"
        )

    # only return devices with modules assigned
    return [device for device in devices if device["modules"]]
# handles naively splitting a model's layers across multiple devices
# this apparently works for training too? the test trainer seemed fine with it split between GPU and CPU
def offload_model( model, policy=None ):
    """Split *model*'s leaf modules across devices according to *policy*.

    Args:
        model: module to distribute. Assigned submodules are moved with ``Module.to``.
        policy: forwarded unchanged to ``get_model_offload_policy``
            (include/exclude filters, per-device "limits", "devices" list).

    Returns:
        *model*, with its assigned submodules moved. For each device whose
        successor differs, the last module on that device has its ``forward``
        replaced by ``auto_to_forward(..., next_device)``; the last device wraps
        around to the first.

    NOTE(review): the wrapper targets the *next* device but is installed on the
    *last* module of the current device, whose weights were just moved to the
    current device. How that interacts with ``auto_to_forward``'s argument
    binding is subtle. Confirm the intended data flow before relying on it.
    """
    policy = get_model_offload_policy(model, policy=policy)

    # move modules to respective devices
    for i, device in enumerate( policy ):
        # nothing assigned, skip
        if not device["modules"]:
            continue

        for name in device["modules"]:
            module = model.get_submodule(name)
            # nn.Module.to presumably moves parameters/buffers in place and returns self,
            # so rebinding `module` here is only cosmetic
            module = module.to( device["name"] )
            """
# in case the above doesn't actually do what's requested
*parent, key = name.split(".")
module = getattr( model.get_submodule(name), key )
setattr( model.get_submodule(name), key, module.to( device["name"] ) )
"""

        # select next device to cast inputs to, or wrap to first if last device
        next_device = policy[i + 1]["name"] if i + 1 < len( policy ) else policy[0]["name"]
        # same device, don't bother wrapping
        if device["name"] == next_device:
            continue

        # wrap forward call
        last_module = model.get_submodule( device["modules"][-1] )
        last_module.forward = auto_to_forward(last_module, next_device)

    # disabled debugging aid, kept as a no-op string literal
    """
# Validate that the layers are all in the right spot
for name, module in model.named_modules():
if not not [*module.named_children()]:
continue
try:
print( name, next(module.parameters()).device )
except Exception as e:
pass
"""
    return model