vall-e/vall_e/utils/utils.py

178 lines
4.3 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
from .distributed import global_rank, local_rank, global_leader_only
import gc
import logging
import pandas as pd
import numpy as np
2023-08-02 21:53:35 +00:00
import re
import torch
import random
import time
2023-08-02 21:53:35 +00:00
from coloredlogs import ColoredFormatter
from logging import StreamHandler
from pathlib import Path
from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
T = TypeVar("T")
def truncate_json(str):
    """Shorten long decimal literals inside a string to four decimals.

    Every match of "digits, a dot, then eight or more digits" is parsed as a
    float and re-rendered with ``"{:.4f}"``. All other text is left as is.

    Args:
        str: Text to scan (presumably serialized JSON, going by the function
            name). The parameter shadows the builtin; the name is kept so
            existing callers are unaffected.

    Returns:
        The input text with each matching number rewritten to four decimals.
    """
    long_decimal = re.compile(r"\d+\.\d{8,}")

    def _shorten(found):
        return "{:.4f}".format(float(found.group()))

    return long_decimal.sub(_shorten, str)
2023-08-02 21:53:35 +00:00
def do_gc():
    """Run a Python garbage-collection pass, then empty PyTorch's CUDA cache."""
    # Collect first, then ask the CUDA allocator to drop its cache.
    gc.collect()
    torch.cuda.empty_cache()
def flatten_dict(d):
    """Flatten a nested dict into a single level using ``pandas.json_normalize``.

    Args:
        d: Possibly nested mapping to flatten.

    Returns:
        The first record of the normalized frame as a dict, or an empty dict
        when normalization yields no records.
    """
    rows = pd.json_normalize(d).to_dict(orient="records")
    if not rows:
        return {}
    return rows[0]
def set_seed(seed=None):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Integer seed. When ``None``, the current Unix time in whole
            seconds is used instead. ``0`` is a valid seed and is honoured.

    Returns:
        The seed that was actually applied, so a time-derived seed can be
        logged and reused to reproduce a run.
    """
    # Compare against None explicitly: a truthiness check would silently
    # replace the legitimate seed 0 with a time-based one.
    if seed is None:
        seed = int(time.time())

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    return seed
2023-08-02 21:53:35 +00:00
def _get_named_modules(module, attrname):
for name, module in module.named_modules():
if hasattr(module, attrname):
yield name, module
def gather_attribute(module, attrname, delete=True, prefix=True):
    """Collect the value of ``attrname`` from every submodule that has it.

    Args:
        module: Root object whose ``named_modules()`` are searched.
        attrname: Name of the attribute to gather.
        delete: When true, remove the attribute from each submodule right
            after reading it.
        prefix: When true, nest the result under ``attrname`` before
            flattening, so every key starts with the attribute name.

    Returns:
        A flat dict keyed by the flattened submodule path, with any run of
        consecutive dots in a key collapsed to a single dot.

    Raises:
        RuntimeError: If deleting the attribute from a submodule fails; the
            original exception is chained as the cause.
    """
    collected = {}
    for name, submodule in _get_named_modules(module, attrname):
        collected[name] = getattr(submodule, attrname)
        if not delete:
            continue
        try:
            delattr(submodule, attrname)
        except Exception as e:
            raise RuntimeError(f"{name} {submodule} {attrname}") from e

    nested = {attrname: collected} if prefix else collected
    flat = flatten_dict(nested)

    # Collapse runs of dots in the flattened keys into a single dot.
    return {re.sub(r"\.+", ".", key): value for key, value in flat.items()}
def dispatch_attribute(
    module,
    attrname,
    value,
    filter_fn: Callable[[nn.Module], bool] | None = None,
):
    """Assign ``value`` to ``attrname`` on every submodule that already has it.

    Args:
        module: Root object whose ``named_modules()`` are searched.
        attrname: Attribute to overwrite; submodules lacking it are skipped.
        value: The value to assign.
        filter_fn: Optional predicate; when given, only submodules for which
            it returns a truthy value are updated.
    """
    for _, target in _get_named_modules(module, attrname):
        if filter_fn is not None and not filter_fn(target):
            continue
        setattr(target, attrname, value)
def load_state_dict_non_strict(model, state_dict, logger=None):
    """Load only the entries of ``state_dict`` that fit ``model``.

    An entry is loaded when its key exists in ``model.state_dict()`` and its
    shape matches. Entries whose shape differs are dropped and treated as not
    provided, so they show up in the "missing" warning.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        state_dict: Mapping of parameter name to tensor-like value.
        logger: Optional logger; when given, extra and missing keys are
            reported through ``logger.warning``.
    """
    current = model.state_dict()
    provided = set(state_dict)
    required = set(current)

    # Keys on both sides whose shapes disagree cannot be loaded.
    mismatched = {
        key
        for key in provided & required
        if current[key].shape != state_dict[key].shape
    }
    provided -= mismatched
    loadable = {key: state_dict[key] for key in provided & required}

    if logger is not None:
        extra = provided - required
        if extra:
            logger.warning(
                f"Extra parameters are found. "
                f"Provided but not required parameters: \n{extra}."
            )
        missing = required - provided
        if missing:
            logger.warning(
                f"Some parameters are missing. "
                f"Required but not provided parameters: \n{missing}."
            )

    model.load_state_dict(loadable, strict=False)
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints formatted records through ``tqdm.write``."""

    def __init__(self, level=logging.INFO):
        # Same as the base handler, except the default level is INFO.
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and write it with ``tqdm.write``.

        Any exception raised while formatting or writing is passed to
        ``handleError`` instead of propagating to the caller.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
2023-08-02 21:53:35 +00:00
@global_leader_only
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
    """Configure root logging with a console handler and an optional file handler.

    The console handler is a ``TqdmLoggingHandler`` fixed at INFO with a
    coloured format that embeds the global and local rank. When ``log_dir``
    is given, a DEBUG-level file handler appends to ``<log_dir>/log.txt``,
    creating the directory if needed.

    Args:
        log_dir: Directory for ``log.txt``; ``None`` disables file logging.
        log_level: Level name for the root logger; it is upper-cased before
            lookup, so case does not matter.

    NOTE(review): the ``global_leader_only`` decorator presumably limits this
    to the leader process; confirm in ``.distributed``.
    """
    console = TqdmLoggingHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(
        ColoredFormatter(
            f"%(asctime)s - %(name)s - %(levelname)s - GR={global_rank()};LR={local_rank()} - \n%(message)s"
        )
    )
    handlers = [console]

    if log_dir is not None:
        log_path = Path(log_dir) / "log.txt"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        to_file = logging.FileHandler(log_path, mode="a")
        to_file.setLevel(logging.DEBUG)
        handlers.append(to_file)

    # The file handler has no formatter of its own, so the plain format
    # passed here is the one it ends up with.
    logging.basicConfig(
        level=logging.getLevelName(log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - \n%(message)s",
        handlers=handlers,
    )
@overload
def tree_map(fn: Callable, x: list[T]) -> list[T]:
    ...


@overload
def tree_map(fn: Callable, x: tuple[T, ...]) -> tuple[T, ...]:
    ...


@overload
def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
    ...


@overload
def tree_map(fn: Callable, x: T) -> T:
    ...


def tree_map(fn: Callable, x):
    """Apply ``fn`` to every ``Tensor`` inside a nested container.

    Lists, tuples and dicts are traversed recursively and rebuilt with the
    same container type. ``Tensor`` leaves are replaced by ``fn(leaf)``, and
    any other value is returned unchanged.

    Args:
        fn: Callable applied to each ``Tensor`` leaf.
        x: A ``Tensor``, a list/tuple/dict nesting of them, or any other
            object.

    Returns:
        A structure of the same shape as ``x`` with tensors mapped by ``fn``.
    """
    if isinstance(x, list):
        return [tree_map(fn, item) for item in x]
    if isinstance(x, tuple):
        # Build a real tuple: a bare generator expression here would hand
        # callers a one-shot generator instead of the container they passed.
        return tuple(tree_map(fn, item) for item in x)
    if isinstance(x, dict):
        return {key: tree_map(fn, value) for key, value in x.items()}
    if isinstance(x, Tensor):
        return fn(x)
    return x
def to_device(x: T, device) -> T:
    """Move every ``Tensor`` found in ``x`` to ``device``.

    ``x`` is walked with ``tree_map``; each tensor leaf is replaced by the
    result of its ``.to(device)`` call.
    """
    def _move(tensor):
        return tensor.to(device)

    return tree_map(_move, x)