mrq 2023-08-02 18:12:36 -05:00
parent 7a06b27a9c
commit 0f9b81de75
2 changed files with 204 additions and 201 deletions


@@ -48,6 +48,7 @@ setup(
         "numpy>=1.23.3",
         "omegaconf==2.0.6",
         "tqdm>=4.64.1",
+        "humanize>=4.4.0",
         "pandas>=1.5.0",
         "torch>=1.13.0",

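For reference, the humanize dependency pinned above is used by the reworked trainer below to pretty-print the estimated training time remaining via humanize.precisedelta. A standalone sketch of that call, with made-up numbers:

import humanize

# e.g. roughly 1.2 s per iteration with 50,000 iterations left
remaining_time = int(50_000 * 1.2)
print(humanize.precisedelta(remaining_time))  # prints something like "16 hours and 40 minutes"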

@@ -2,252 +2,254 @@
 # https://github.com/enhuiz/pytorch-training-utilities
 """
-# todo: replace this
+import humanize
+import json
 import logging
-import time
-from typing import Any, Protocol
+import numpy as np
+import random
+import selectors
+import sys
 import torch
-import torch.distributed
-from deepspeed import DeepSpeedEngine
-from torch import Tensor
-from torch.distributed import all_reduce
-
-from .config import Config
-from .distributed import fix_unset_envs
-from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-
-Stats = dict[str, float]
+
+from functools import cache
+from torch.distributed import broadcast_object_list
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from typing import Protocol
+
+from ..config import Config
+from .distributed import (
+    global_leader_only,
+    global_rank,
+    is_global_leader,
+    is_local_leader,
+    local_leader_only,
+)
+from .engines import Engine, Engines, TrainFeeder
+from .utils import to_device, do_gc
 
 _logger = logging.getLogger(__name__)
 
-class Engine(DeepSpeedEngine):
-    def __init__(self, *args, **kwargs):
-        fix_unset_envs()
-        super().__init__(None, *args, **kwargs)
-        self._frozen_params = set()
-
-    def freeze(self):
-        for p in self.module.parameters():
-            if p.requires_grad:
-                p.requires_grad_(False)
-                self._frozen_params.add(p)
-
-    def unfreeze(self):
-        for p in self._frozen_params:
-            p.requires_grad_(True)
-        self._frozen_params.clear()
-
-    @property
-    def global_step(self):
-        return self.global_steps
-
-    def gather_attribute(self, *args, **kwargs):
-        return gather_attribute(self.module, *args, **kwargs)
-
-    def dispatch_attribute(self, *args, **kwargs):
-        return dispatch_attribute(self.module, *args, **kwargs)
-
-class TrainFeeder(Protocol):
-    def __call__(
-        self, *, engines: "Engines", batch: Any, name: str
-    ) -> None | tuple[Tensor, Stats]:
-        ...
-
-class Engines(dict[str, Engine]):
-    def setup(self, cfg: Config):
-        self._cfg = cfg
-        self._global_step = 0
-
-    @property
-    def cfg(self) -> Config:
-        return self._cfg
-
-    @property
-    def config(self):
-        return self._cfg
-
-    @property
-    def global_step(self):
-        return self._global_step
-
-    def gather_attribute(self, *args, **kwargs):
-        ret = {}
-        for engine in self.values():
-            ret |= engine.gather_attribute(*args, **kwargs)
-        return ret
-
-    def dispatch_attribute(self, *args, **kwargs):
-        for engine in self.values():
-            engine.dispatch_attribute(*args, **kwargs)
-
-    def save_checkpoint(self, tag=None):
-        if not tag:
-            tag = self.cfg.trainer.save_tag
-        tag = tag.lower()
-        if tag[:2] == "it" or tag[:4] == "step":
-            tag = self.global_step
-
-        self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
-        for name, engine in self.items():
-            engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
-
-    def load_checkpoint(self, tag=None):
-        if not tag:
-            tag = self.cfg.trainer.load_tag
-
-        for name, engine in self.items():
-            load_dir = self.cfg.ckpt_dir / name
-            engine.load_checkpoint(
-                tag=tag,
-                load_dir=load_dir,
-                load_module_strict=self.cfg.trainer.strict_loading,
-                load_optimizer_states=self.cfg.trainer.load_states,
-                load_lr_scheduler_states=self.cfg.trainer.load_states,
-                load_module_only=False, # not self.cfg.trainer.load_states,
-            )
-            if self.cfg.trainer.restart_step_count:
-                engine.global_steps = 0
-
-        # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
-        if self.cfg.hyperparameters.scheduler_type == "":
-            self.set_lr(self.cfg.hyperparameters.learning_rate)
-
-        self._update_global_step()
-
-    def set_lr(self, lr):
-        try:
-            for engine in self.values():
-                if hasattr(engine.optimizer, 'param_groups'):
-                    print(engine.optimizer.param_groups)
-                    for param_group in engine.optimizer.param_groups:
-                        param_group['lr'] = lr
-                else:
-                    engine.optimizer.set_lr(lr)
-        except Exception as e:
-            print(str(e))
-
-    def _update_global_step(self):
-        for engine in self.values():
-            self._global_step = max(self._global_step, engine.global_step)
-
-    def eval(self):
-        for engine in self.values():
-            engine.eval()
-
-    def train(self):
-        for engine in self.values():
-            engine.train()
-
-    def step(self, feeder: TrainFeeder, batch):
-        total_elapsed_time = 0
-
-        stats: Any = dict()
-
-        if self.cfg.trainer.gc_mode == 'step':
-            do_gc()
-
-        batch = to_device(batch, torch.cuda.current_device())
-
-        for name, engine in self.items():
-            torch.cuda.synchronize()
-
-            if self.cfg.trainer.gc_mode == 'substep':
-                do_gc()
-
-            start_time = time.time()
-
-            tries = 4
-            n_ooms = torch.zeros([], device=self.cfg.device)
-
-            if self.cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, torch.cuda.current_device())
-                # engine = engine.to(torch.cuda.current_device())
-
-            while tries >= 0:
-                try:
-                    maybe_loss_and_engine_stats = feeder( engines=self, batch=batch, name=name )
-                    break
-                except RuntimeError as e:
-                    print("Forward", str(e))
-
-                    if "out of memory" not in str(e):
-                        self.save_checkpoint()
-                        raise e
-
-                    # shrink batch size until it's happy
-                    for k in batch:
-                        batch[k] = batch[k][:-1]
-
-                    if tries <= 0:
-                        # trigger OOM
-                        n_ooms += 1
-                    else:
-                        # also do GC
-                        do_gc()
-                    continue
-
-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during forward pass!")
-
-            # Here we allow skip optimizers. It's useful when, for example,
-            # skipping discriminators in the begining of GAN training.
-            if maybe_loss_and_engine_stats is None:
-                continue
-
-            loss, engine_stats = maybe_loss_and_engine_stats
-
-            n_ooms = torch.zeros([], device=self.cfg.device)
-
-            if self.cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, 'cpu')
-
-            try:
-                engine.backward(loss)
-            except RuntimeError as e:
-                print("Backwards:", str(e))
-
-                if "out of memory" not in str(e):
-                    self.save_checkpoint()
-                    raise e
-
-                n_ooms += 1
-
-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during backwards pass!")
-
-            engine.step()
-            torch.cuda.synchronize()
-            elapsed_time = time.time() - start_time
-            total_elapsed_time += elapsed_time
-
-            stats.update(
-                flatten_dict(
-                    {
-                        name.split("-")[0]: dict(
-                            loss=loss.item(),
-                            lr=engine.get_lr()[0],
-                            grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
-                            elapsed_time=elapsed_time,
-                            engine_step=engine.global_step,
-                            **engine_stats,
-                        )
-                    }
-                ),
-            )
-            del loss
-            # engine = engine.to('cpu')
-
-        self._update_global_step()
-        stats["batch_size"] = len(batch["text"])
-        stats["elapsed_time"] = total_elapsed_time
-        stats["wall_time"] = time.time()
-        stats["global_step"] = self.global_step
-
-        return stats
+_engines: Engines
+_command: str
+
+def get_global_step():
+    try:
+        return _engines.global_step
+    except:
+        return None
+
+def get_micro_step():
+    try:
+        return _engines.micro_step
+    except:
+        return None
+
+def get_cfg():
+    try:
+        return _engines.cfg
+    except:
+        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+
+def get_cmd():
+    try:
+        return _command
+    except:
+        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+
+get_iteration = get_global_step
+
+class EnginesLoader(Protocol):
+    def __call__(self) -> Engines:
+        ...
+
+def load_engines(engines: dict[str, Engine], config: Config):
+    engines = Engines(engines)
+    engines.setup(config)
+    if not engines.cfg.trainer.load_state_dict:
+        engines.load_checkpoint()
+    return engines
+
+class EvalFn(Protocol):
+    def __call__(self, *, engines: Engines):
+        ...
+
+class Logger(Protocol):
+    def __call__(self, *, data: dict):
+        ...
+
+@cache
+def _get_stdin_selector():
+    selector = selectors.DefaultSelector()
+    selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+    return selector
+
+def _non_blocking_input():
+    global _command
+    l = [""]
+    if is_global_leader():
+        s = ""
+        selector = _get_stdin_selector()
+        events = selector.select(timeout=0)
+        for key, _ in events:
+            s: str = key.fileobj.readline().strip()
+            _logger.info(f'Get stdin "{s}".')
+        l[0] = s
+    broadcast_object_list(l, src=0)
+    _command = l[0]
+    return _command
+
+def _make_infinite_epochs(dl):
+    while True:
+        _logger.info("New epoch starts.")
+        yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
+
+@local_leader_only(default=None)
+def logger(data):
+    return _logger.info(json.dumps(data, default=str))
+
+def seed(seed):
+    # Set up random seeds, after fork()
+    random.seed(seed + global_rank())
+    np.random.seed(seed + global_rank())
+    torch.manual_seed(seed + global_rank())
+
+def train(
+    engines_loader: EnginesLoader,
+    train_dl: DataLoader,
+    train_feeder: TrainFeeder,
+    eval_fn: EvalFn,
+    logger: Logger = logger,
+):
+    engines = engines_loader()
+    cfg = engines.cfg
+
+    """
+    if is_local_leader():
+        cfg.dump()
+        _logger.info(cfg)
+    """
+
+    # Setup global engines
+    global _engines
+    _engines = engines
+
+    events = []
+
+    eval_fn = global_leader_only(eval_fn)
+
+    # Pre-loop command
+    command = _non_blocking_input()
+    if command in ["eval", "eval_quit"]:
+        engines.eval()
+        eval_fn(engines=engines)
+        engines.train()
+    if command in ["quit", "eval_quit"]:
+        return
+
+    last_save_step = engines.global_step
+    last_eval_step = 0
+
+    # Training loop
+    for batch in _make_infinite_epochs(train_dl):
+        if engines.global_step >= cfg.trainer.iterations:
+            break
+
+        #batch = to_device(batch, torch.cuda.current_device())
+        stats = engines.step(feeder=train_feeder, batch=batch)
+
+        iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
+        stats['it'] = iteration
+        stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+
+        stats['batch'] = {
+            'size': stats['batch_size'],
+            'id': batch['spkr_id'],
+            'index': [ index for index in batch['index'] ],
+            'text_len': [ text.shape[0] for text in batch['text'] ],
+            'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
+            'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
+        }
+
+        del stats['batch_size']
+        del stats['wall_time']
+        del stats['global_step']
+
+        elapsed_time = stats.get("elapsed_time", 0)
+        _logger.info(f"Training Metrics: {json.dumps(stats)}.")
+
+        command = _non_blocking_input()
+
+        if "@" in command:
+            what, when = command.split("@")
+            try:
+                events.append((what, int(when)))
+                _logger.info(f"Event {command} registered.")
+            except Exception as e:
+                _logger.error(e)
+            command = ""
+
+        # Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
+        events = [e for e in events if e[1] >= engines.global_step]
+        commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
+
+        for command in commands:
+            if command in ["event show", "event"]:
+                msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
+                _logger.info(msg)
+
+            if command == "event clear":
+                events.clear()
+
+            if "time" in command:
+                target_iter = cfg.trainer.iterations
+                if " to " in command:
+                    try:
+                        target_iter = int(command.split(" to ")[-1])
+                    except Exception as e:
+                        _logger.error(e)
+                remaining_iters = target_iter - engines.global_step + 1
+                remaining_time = int(remaining_iters * elapsed_time)
+                _logger.info(humanize.precisedelta(remaining_time))
+
+            if "lr" in command:
+                rate = float(command.split(" ")[-1])
+                engines.set_lr(rate)
+                print("Updating LR to:", rate)
+
+            save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
+
+            saving_commands = ["save"]
+
+            if cfg.trainer.save_on_quit:
+                saving_commands.append("quit")
+
+            if engines.global_step != last_save_step:
+                if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+                    engines.save_checkpoint()
+                    last_save_step = engines.global_step
+
+            if engines.global_step != last_eval_step:
+                if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
+                    do_gc()
+
+                    engines.eval()
+                    eval_fn(engines=engines)
+                    engines.train()
+                    last_eval_step = engines.global_step
+
+            if command in ["quit"]:
+                return