diff --git a/setup.py b/setup.py
index 203be90..6a514a2 100755
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,7 @@ setup(
         "numpy>=1.23.3",
         "omegaconf==2.0.6",
         "tqdm>=4.64.1",
+        "humanize>=4.4.0",
         "pandas>=1.5.0",
         "torch>=1.13.0",
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 8e9f86c..de47763 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -2,252 +2,254 @@
 # https://github.com/enhuiz/pytorch-training-utilities
 """

-# todo: replace this
-
+import humanize
+import json
 import logging
-import time
-from typing import Any, Protocol
-
+import numpy as np
+import random
+import selectors
+import sys
 import torch
-import torch.distributed
-from deepspeed import DeepSpeedEngine
-from torch import Tensor
-from torch.distributed import all_reduce
-from .config import Config
-from .distributed import fix_unset_envs
-from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
+from functools import cache
+from torch.distributed import broadcast_object_list
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from typing import Protocol

-Stats = dict[str, float]
+from ..config import Config
+from .distributed import (
+    global_leader_only,
+    global_rank,
+    is_global_leader,
+    is_local_leader,
+    local_leader_only,
+)
+
+from .engines import Engine, Engines, TrainFeeder
+from .utils import to_device, do_gc

 _logger = logging.getLogger(__name__)

+_engines: Engines
+_command: str
+
+def get_global_step():
+    try:
+        return _engines.global_step
+    except (NameError, AttributeError):
+        return None
+
+def get_micro_step():
+    try:
+        return _engines.micro_step
+    except (NameError, AttributeError):
+        return None

-class Engine(DeepSpeedEngine):
-    def __init__(self, *args, **kwargs):
-        fix_unset_envs()
-        super().__init__(None, *args, **kwargs)
-        self._frozen_params = set()
-
-    def freeze(self):
-        for p in self.module.parameters():
-            if p.requires_grad:
-                p.requires_grad_(False)
-                self._frozen_params.add(p)
-
-    def unfreeze(self):
-        for p in self._frozen_params:
-            p.requires_grad_(True)
-        self._frozen_params.clear()
-
-    @property
-    def global_step(self):
-        return self.global_steps
-
-    def gather_attribute(self, *args, **kwargs):
-        return gather_attribute(self.module, *args, **kwargs)
-
-    def dispatch_attribute(self, *args, **kwargs):
-        return dispatch_attribute(self.module, *args, **kwargs)
+def get_cfg():
+    try:
+        return _engines.cfg
+    except (NameError, AttributeError):
+        raise RuntimeError("Trainer has not been set up. Have you called trainer.train?")

-class TrainFeeder(Protocol):
-    def __call__(
-        self, *, engines: "Engines", batch: Any, name: str
-    ) -> None | tuple[Tensor, Stats]:
-        ...
+def get_cmd():
+    try:
+        return _command
+    except NameError:
+        raise RuntimeError("Trainer has not been set up. Have you called trainer.train?")

-class Engines(dict[str, Engine]):
-    def setup(self, cfg: Config):
-        self._cfg = cfg
-        self._global_step = 0
+get_iteration = get_global_step

-    @property
-    def cfg(self) -> Config:
-        return self._cfg
-
-    @property
-    def config(self):
-        return self._cfg
+class EnginesLoader(Protocol):
+    def __call__(self) -> Engines:
+        ...
-    @property
-    def global_step(self):
-        return self._global_step

-    def gather_attribute(self, *args, **kwargs):
-        ret = {}
-        for engine in self.values():
-            ret |= engine.gather_attribute(*args, **kwargs)
-        return ret
+def load_engines(engines: dict[str, Engine], config: Config):
+    engines = Engines(engines)
+    engines.setup(config)
+    if not engines.cfg.trainer.load_state_dict:
+        engines.load_checkpoint()
+    return engines

-    def dispatch_attribute(self, *args, **kwargs):
-        for engine in self.values():
-            engine.dispatch_attribute(*args, **kwargs)

-    def save_checkpoint(self, tag=None):
-        if not tag:
-            tag = self.cfg.trainer.save_tag
-        tag = tag.lower()
-        if tag[:2] == "it" or tag[:4] == "step":
-            tag = self.global_step
+class EvalFn(Protocol):
+    def __call__(self, *, engines: Engines):
+        ...

-        self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
-        for name, engine in self.items():
-            engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)

-    def load_checkpoint(self, tag=None):
-        if not tag:
-            tag = self.cfg.trainer.load_tag
+class Logger(Protocol):
+    def __call__(self, *, data: dict):
+        ...

-        for name, engine in self.items():
-            load_dir = self.cfg.ckpt_dir / name
-            engine.load_checkpoint(
-                tag=tag,
-                load_dir=load_dir,
-                load_module_strict=self.cfg.trainer.strict_loading,
-                load_optimizer_states=self.cfg.trainer.load_states,
-                load_lr_scheduler_states=self.cfg.trainer.load_states,
-                load_module_only=False, # not self.cfg.trainer.load_states,
-            )
-            if self.cfg.trainer.restart_step_count:
-                engine.global_steps = 0

-        # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
-        if self.cfg.hyperparameters.scheduler_type == "":
-            self.set_lr(self.cfg.hyperparameters.learning_rate)
+@cache
+def _get_stdin_selector():
+    selector = selectors.DefaultSelector()
+    selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+    return selector

-        self._update_global_step()

-    def set_lr(self, lr):
-        try:
-            for engine in self.values():
-                if hasattr(engine.optimizer, 'param_groups'):
-                    print(engine.optimizer.param_groups)
-                    for param_group in engine.optimizer.param_groups:
-                        param_group['lr'] = lr
-                else:
-                    engine.optimizer.set_lr(lr)
-        except Exception as e:
-            print(str(e))
+def _non_blocking_input():
+    global _command
+    l = [""]
+    if is_global_leader():
+        s = ""
+        selector = _get_stdin_selector()
+        events = selector.select(timeout=0)
+        for key, _ in events:
+            s: str = key.fileobj.readline().strip()
+            _logger.info(f'Got stdin "{s}".')
+        l[0] = s
+    broadcast_object_list(l, src=0)
+    _command = l[0]
+    return _command

-    def _update_global_step(self):
-        for engine in self.values():
-            self._global_step = max(self._global_step, engine.global_step)

-    def eval(self):
-        for engine in self.values():
-            engine.eval()
+def _make_infinite_epochs(dl):
+    while True:
+        _logger.info("New epoch starts.")
+        yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)

-    def train(self):
-        for engine in self.values():
-            engine.train()

-    def step(self, feeder: TrainFeeder, batch):
-        total_elapsed_time = 0
+@local_leader_only(default=None)
+def logger(data):
+    return _logger.info(json.dumps(data, default=str))

-        stats: Any = dict()

-        if self.cfg.trainer.gc_mode == 'step':
-            do_gc()
+def seed(seed):
+    # Set up random seeds, after fork()
+    random.seed(seed + global_rank())
+    np.random.seed(seed + global_rank())
+    torch.manual_seed(seed + global_rank())

-        batch = to_device(batch, torch.cuda.current_device())

-        for name, engine in self.items():
-            torch.cuda.synchronize()

-            if self.cfg.trainer.gc_mode == 'substep':
-                do_gc()
+def train(
+    engines_loader: EnginesLoader,
+    train_dl: DataLoader,
+    train_feeder: TrainFeeder,
+    eval_fn: EvalFn,
+    logger: Logger = logger,
+):
+    engines = engines_loader()
+    cfg = engines.cfg

-            start_time = time.time()
+    """
+    if is_local_leader():
+        cfg.dump()
+        _logger.info(cfg)
+    """

-            tries = 4
-            n_ooms = torch.zeros([], device=self.cfg.device)
-
-            if self.cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, torch.cuda.current_device())
-                # engine = engine.to(torch.cuda.current_device())
+    # Setup global engines
+    global _engines
+    _engines = engines

-            while tries >= 0:
-                try:
-                    maybe_loss_and_engine_stats = feeder( engines=self, batch=batch, name=name )
-                    break
-                except RuntimeError as e:
-                    print("Forward", str(e))
+    events = []

-                    if "out of memory" not in str(e):
-                        self.save_checkpoint()
-                        raise e
+    eval_fn = global_leader_only(eval_fn)

-                    # shrink batch size until it's happy
-                    for k in batch:
-                        batch[k] = batch[k][:-1]
+    # Pre-loop command
+    command = _non_blocking_input()
+    if command in ["eval", "eval_quit"]:
+        engines.eval()
+        eval_fn(engines=engines)
+        engines.train()
+    if command in ["quit", "eval_quit"]:
+        return

-                    if tries <= 0:
-                        # trigger OOM
-                        n_ooms += 1
-                    else:
-                        # also do GC
-                        do_gc()
-                    continue
+    last_save_step = engines.global_step
+    last_eval_step = 0

-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during forward pass!")
+    # Training loop
+    for batch in _make_infinite_epochs(train_dl):
+        if engines.global_step >= cfg.trainer.iterations:
+            break

-            # Here we allow skip optimizers. It's useful when, for example,
-            # skipping discriminators in the begining of GAN training.
-            if maybe_loss_and_engine_stats is None:
-                continue
-
-            loss, engine_stats = maybe_loss_and_engine_stats
+        #batch = to_device(batch, torch.cuda.current_device())
+        stats = engines.step(feeder=train_feeder, batch=batch)

-            n_ooms = torch.zeros([], device=self.cfg.device)
-
-            if self.cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, 'cpu')
+        iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
+        stats['it'] = iteration
+        stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)

-            try:
-                engine.backward(loss)
-            except RuntimeError as e:
-                print("Backwards:", str(e))
+        stats['batch'] = {
+            'size': stats['batch_size'],
+            'id': batch['spkr_id'],
+            'index': [ index for index in batch['index'] ],
+            'text_len': [ text.shape[0] for text in batch['text'] ],
+            'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
+            'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
+        }

-                if "out of memory" not in str(e):
-                    self.save_checkpoint()
-                    raise e
-
-                n_ooms += 1
+        del stats['batch_size']
+        del stats['wall_time']
+        del stats['global_step']

-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during backwards pass!")
+        elapsed_time = stats.get("elapsed_time", 0)
+        _logger.info(f"Training Metrics: {json.dumps(stats)}.")

-            engine.step()
-            torch.cuda.synchronize()
-            elapsed_time = time.time() - start_time
-            total_elapsed_time += elapsed_time
+        command = _non_blocking_input()

-            stats.update(
-                flatten_dict(
-                    {
-                        name.split("-")[0]: dict(
-                            loss=loss.item(),
-                            lr=engine.get_lr()[0],
-                            grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
-                            elapsed_time=elapsed_time,
-                            engine_step=engine.global_step,
-                            **engine_stats,
-                        )
-                    }
-                ),
-            )
-            del loss
-            # engine = engine.to('cpu')
+        if "@" in command:
+            what, when = command.split("@")
+            try:
+                events.append((what, int(when)))
+                _logger.info(f"Event {command} registered.")
+            except Exception as e:
+                _logger.error(e)
+            command = ""

-        self._update_global_step()
-        stats["batch_size"] = len(batch["text"])
-        stats["elapsed_time"] = total_elapsed_time
-        stats["wall_time"] = time.time()
-        stats["global_step"] = self.global_step
+        # Keep only pending events, then fire those scheduled for the current iteration alongside the typed command
+        events = [e for e in events if e[1] >= engines.global_step]
+        commands = [command] + [e[0] for e in events if e[1] == engines.global_step]

-        return stats
+        for command in commands:
+            if command in ["event show", "event"]:
+                msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
+                _logger.info(msg)
+
+            if command == "event clear":
+                events.clear()
+
+            if "time" in command:
+                target_iter = cfg.trainer.iterations
+                if " to " in command:
+                    try:
+                        target_iter = int(command.split(" to ")[-1])
+                    except Exception as e:
+                        _logger.error(e)
+                remaining_iters = target_iter - engines.global_step + 1
+                remaining_time = int(remaining_iters * elapsed_time)
+                _logger.info(humanize.precisedelta(remaining_time))
+
+            if "lr" in command:
+                try:
+                    rate = float(command.split(" ")[-1])
+                    engines.set_lr(rate)
+                    _logger.info(f"Updating LR to: {rate}")
+                except ValueError as e:
+                    _logger.error(e)
+
+            save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
+
+            saving_commands = ["save"]
+
+            if cfg.trainer.save_on_quit:
+                saving_commands.append("quit")
+
+            if engines.global_step != last_save_step:
+                if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+                    engines.save_checkpoint()
+                    last_save_step = engines.global_step
+
+            if engines.global_step != last_eval_step:
+                if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
+                    do_gc()
+
+                    engines.eval()
+                    eval_fn(engines=engines)
+                    engines.train()
+                    last_eval_step = engines.global_step
+
+            if command in ["quit"]:
+                return
\ No newline at end of file
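
Usage note: this patch turns trainer.py into a thin loop around the Engines abstraction that now lives in vall_e/utils/engines.py. Below is a minimal sketch of how the new entry points might be wired together. Only load_engines and the train(...) signature come from this diff; the model factory, dataloader, and feeder bodies are hypothetical stand-ins, and the exact Engine constructor and TrainFeeder contract are defined in engines.py, which is not part of this patch.

    # Hypothetical wiring for the reworked trainer. Only load_engines and
    # train() appear in this patch; everything else is a labeled stand-in.
    from vall_e.config import cfg                  # assumed: the project's config object
    from vall_e.utils import trainer
    from vall_e.utils.engines import Engine, Engines

    def build_model():
        ...                                        # hypothetical model factory

    def engines_loader() -> Engines:
        engine = Engine(model=build_model())       # Engine ctor args depend on engines.py
        return trainer.load_engines({"model": engine}, cfg)

    def train_feeder(*, engines, batch, name):
        ...                                        # return (loss, stats) for the named engine

    def eval_fn(*, engines):
        ...                                        # train() wraps this with global_leader_only()

    train_dl = ...                                 # hypothetical: a torch DataLoader over the training set

    trainer.train(
        engines_loader=engines_loader,
        train_dl=train_dl,
        train_feeder=train_feeder,
        eval_fn=eval_fn,
    )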
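
The loop also polls stdin once per step via _non_blocking_input(), so a running job can be steered from the terminal. The command is read on the global leader and synchronized to the other ranks with broadcast_object_list, so every process acts on the same input. From the handlers above, the recognized commands are:

    quit                   stop training (checkpointing first when cfg.trainer.save_on_quit is set)
    eval / eval_quit       run eval_fn immediately (eval_quit then quits)
    save                   save a checkpoint immediately
    lr <rate>              set the learning rate on every engine
    time [to <iteration>]  estimate remaining wall time from the last step's duration
    event / event show     list scheduled events
    event clear            drop all scheduled events
    <command>@<iteration>  schedule a command for a later step, e.g. eval@10000

As a worked example of the time estimate: at iteration 90,000 of cfg.trainer.iterations = 100,000, with 0.5 s per step, it logs humanize.precisedelta(int((100000 - 90000 + 1) * 0.5)), i.e. roughly "1 hour, 23 minutes and 20 seconds".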