mrq 2023-08-02 18:12:36 -05:00
parent 7a06b27a9c
commit 0f9b81de75
2 changed files with 204 additions and 201 deletions

View File

@@ -48,6 +48,7 @@ setup(
         "numpy>=1.23.3",
         "omegaconf==2.0.6",
         "tqdm>=4.64.1",
+        "humanize>=4.4.0",
         "pandas>=1.5.0",
         "torch>=1.13.0",

View File

@@ -2,252 +2,254 @@
 # https://github.com/enhuiz/pytorch-training-utilities
 """
-# todo: replace this
+import humanize
+import json
 import logging
-import time
-from typing import Any, Protocol
+import numpy as np
+import random
+import selectors
+import sys
 import torch
-from deepspeed import DeepSpeedEngine
-from torch import Tensor
-from torch.distributed import all_reduce
-
-from .config import Config
-from .distributed import fix_unset_envs
-from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
-
-Stats = dict[str, float]
+import torch.distributed
+
+from functools import cache
+from torch.distributed import broadcast_object_list
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from typing import Protocol
+
+from ..config import Config
+from .distributed import (
+    global_leader_only,
+    global_rank,
+    is_global_leader,
+    is_local_leader,
+    local_leader_only,
+)
+from .engines import Engine, Engines, TrainFeeder
+from .utils import to_device, do_gc
 _logger = logging.getLogger(__name__)
-
-class Engine(DeepSpeedEngine):
-    def __init__(self, *args, **kwargs):
-        fix_unset_envs()
-        super().__init__(None, *args, **kwargs)
-        self._frozen_params = set()
-
-    def freeze(self):
-        for p in self.module.parameters():
-            if p.requires_grad:
-                p.requires_grad_(False)
-                self._frozen_params.add(p)
-
-    def unfreeze(self):
-        for p in self._frozen_params:
-            p.requires_grad_(True)
-        self._frozen_params.clear()
-
-    @property
-    def global_step(self):
-        return self.global_steps
-
-    def gather_attribute(self, *args, **kwargs):
-        return gather_attribute(self.module, *args, **kwargs)
-
-    def dispatch_attribute(self, *args, **kwargs):
-        return dispatch_attribute(self.module, *args, **kwargs)
-
-class TrainFeeder(Protocol):
-    def __call__(
-        self, *, engines: "Engines", batch: Any, name: str
-    ) -> None | tuple[Tensor, Stats]:
-        ...
-
-class Engines(dict[str, Engine]):
-    def setup(self, cfg: Config):
-        self._cfg = cfg
-        self._global_step = 0
-
-    @property
-    def cfg(self) -> Config:
-        return self._cfg
-
-    @property
-    def config(self):
-        return self._cfg
-
-    @property
-    def global_step(self):
-        return self._global_step
-
-    def gather_attribute(self, *args, **kwargs):
-        ret = {}
-        for engine in self.values():
-            ret |= engine.gather_attribute(*args, **kwargs)
-        return ret
-
-    def dispatch_attribute(self, *args, **kwargs):
-        for engine in self.values():
-            engine.dispatch_attribute(*args, **kwargs)
-
-    def save_checkpoint(self, tag=None):
-        if not tag:
-            tag = self.cfg.trainer.save_tag
-        tag = tag.lower()
-        if tag[:2] == "it" or tag[:4] == "step":
-            tag = self.global_step
-
-        self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
-        for name, engine in self.items():
-            engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
-
-    def load_checkpoint(self, tag=None):
-        if not tag:
-            tag = self.cfg.trainer.load_tag
-
-        for name, engine in self.items():
-            load_dir = self.cfg.ckpt_dir / name
-            engine.load_checkpoint(
-                tag=tag,
-                load_dir=load_dir,
-                load_module_strict=self.cfg.trainer.strict_loading,
-                load_optimizer_states=self.cfg.trainer.load_states,
-                load_lr_scheduler_states=self.cfg.trainer.load_states,
-                load_module_only=False, # not self.cfg.trainer.load_states,
-            )
-            if self.cfg.trainer.restart_step_count:
-                engine.global_steps = 0
-
-        # update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
-        if self.cfg.hyperparameters.scheduler_type == "":
-            self.set_lr(self.cfg.hyperparameters.learning_rate)
-
-        self._update_global_step()
-
-    def set_lr(self, lr):
-        try:
-            for engine in self.values():
-                if hasattr(engine.optimizer, 'param_groups'):
-                    print(engine.optimizer.param_groups)
-                    for param_group in engine.optimizer.param_groups:
-                        param_group['lr'] = lr
-                else:
-                    engine.optimizer.set_lr(lr)
-        except Exception as e:
-            print(str(e))
-
-    def _update_global_step(self):
-        for engine in self.values():
-            self._global_step = max(self._global_step, engine.global_step)
-
-    def eval(self):
-        for engine in self.values():
-            engine.eval()
-
-    def train(self):
-        for engine in self.values():
-            engine.train()
-
-    def step(self, feeder: TrainFeeder, batch):
-        total_elapsed_time = 0
-
-        stats: Any = dict()
-
-        if self.cfg.trainer.gc_mode == 'step':
-            do_gc()
-
-        batch = to_device(batch, torch.cuda.current_device())
-
-        for name, engine in self.items():
-            torch.cuda.synchronize()
-            if self.cfg.trainer.gc_mode == 'substep':
-                do_gc()
-
-            start_time = time.time()
-
-            tries = 4
-            n_ooms = torch.zeros([], device=self.cfg.device)
-
-            if self.cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, torch.cuda.current_device())
-                # engine = engine.to(torch.cuda.current_device())
-
-            while tries >= 0:
-                try:
-                    maybe_loss_and_engine_stats = feeder( engines=self, batch=batch, name=name )
-                    break
-                except RuntimeError as e:
-                    print("Forward", str(e))
-
-                    if "out of memory" not in str(e):
-                        self.save_checkpoint()
-                        raise e
-
-                    # shrink batch size until it's happy
-                    for k in batch:
-                        batch[k] = batch[k][:-1]
-
-                    if tries <= 0:
-                        # trigger OOM
-                        n_ooms += 1
-                    else:
-                        # also do GC
-                        do_gc()
-                    continue
-
-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during forward pass!")
-
-            # Here we allow skip optimizers. It's useful when, for example,
-            # skipping discriminators in the begining of GAN training.
-            if maybe_loss_and_engine_stats is None:
-                continue
-
-            loss, engine_stats = maybe_loss_and_engine_stats
-
-            n_ooms = torch.zeros([], device=self.cfg.device)
-
-            if self.cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, 'cpu')
-
-            try:
-                engine.backward(loss)
-            except RuntimeError as e:
-                print("Backwards:", str(e))
-
-                if "out of memory" not in str(e):
-                    self.save_checkpoint()
-                    raise e
-
-                n_ooms += 1
-
-            all_reduce(n_ooms)
-            if n_ooms.item() > 0:
-                self.save_checkpoint()
-                raise RuntimeError("Out of memory during backwards pass!")
-
-            engine.step()
-            torch.cuda.synchronize()
-            elapsed_time = time.time() - start_time
-            total_elapsed_time += elapsed_time
-
-            stats.update(
-                flatten_dict(
-                    {
-                        name.split("-")[0]: dict(
-                            loss=loss.item(),
-                            lr=engine.get_lr()[0],
-                            grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
-                            elapsed_time=elapsed_time,
-                            engine_step=engine.global_step,
-                            **engine_stats,
-                        )
-                    }
-                ),
-            )
-
-            del loss
-            # engine = engine.to('cpu')
-
-        self._update_global_step()
-        stats["batch_size"] = len(batch["text"])
-        stats["elapsed_time"] = total_elapsed_time
-        stats["wall_time"] = time.time()
-        stats["global_step"] = self.global_step
-
-        return stats
+
+_engines: Engines
+_command: str
+
+def get_global_step():
+    try:
+        return _engines.global_step
+    except:
+        return None
+
+def get_micro_step():
+    try:
+        return _engines.micro_step
+    except:
+        return None
+
+def get_cfg():
+    try:
+        return _engines.cfg
+    except:
+        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+
+def get_cmd():
+    try:
+        return _command
+    except:
+        raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
+
+get_iteration = get_global_step
+
+class EnginesLoader(Protocol):
+    def __call__(self) -> Engines:
+        ...
+
+def load_engines(engines: dict[str, Engine], config: Config):
+    engines = Engines(engines)
+    engines.setup(config)
+    if not engines.cfg.trainer.load_state_dict:
+        engines.load_checkpoint()
+    return engines
+
+class EvalFn(Protocol):
+    def __call__(self, *, engines: Engines):
+        ...
+
+class Logger(Protocol):
+    def __call__(self, *, data: dict):
+        ...
+
+@cache
+def _get_stdin_selector():
+    selector = selectors.DefaultSelector()
+    selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+    return selector
+
+def _non_blocking_input():
+    global _command
+    l = [""]
+    if is_global_leader():
+        s = ""
+        selector = _get_stdin_selector()
+        events = selector.select(timeout=0)
+        for key, _ in events:
+            s: str = key.fileobj.readline().strip()
+            _logger.info(f'Get stdin "{s}".')
+        l[0] = s
+    broadcast_object_list(l, src=0)
+    _command = l[0]
+    return _command
+
+def _make_infinite_epochs(dl):
+    while True:
+        _logger.info("New epoch starts.")
+        yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
+
+@local_leader_only(default=None)
+def logger(data):
+    return _logger.info(json.dumps(data, default=str))
+
+def seed(seed):
+    # Set up random seeds, after fork()
+    random.seed(seed + global_rank())
+    np.random.seed(seed + global_rank())
+    torch.manual_seed(seed + global_rank())
+
+def train(
+    engines_loader: EnginesLoader,
+    train_dl: DataLoader,
+    train_feeder: TrainFeeder,
+    eval_fn: EvalFn,
+    logger: Logger = logger,
+):
+    engines = engines_loader()
+    cfg = engines.cfg
+
+    """
+    if is_local_leader():
+        cfg.dump()
+        _logger.info(cfg)
+    """
+
+    # Setup global engines
+    global _engines
+    _engines = engines
+
+    events = []
+    eval_fn = global_leader_only(eval_fn)
+
+    # Pre-loop command
+    command = _non_blocking_input()
+    if command in ["eval", "eval_quit"]:
+        engines.eval()
+        eval_fn(engines=engines)
+        engines.train()
+    if command in ["quit", "eval_quit"]:
+        return
+
+    last_save_step = engines.global_step
+    last_eval_step = 0
+
+    # Training loop
+    for batch in _make_infinite_epochs(train_dl):
+        if engines.global_step >= cfg.trainer.iterations:
+            break
+
+        #batch = to_device(batch, torch.cuda.current_device())
+        stats = engines.step(feeder=train_feeder, batch=batch)
+
+        iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
+        stats['it'] = iteration
+        stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
+
+        stats['batch'] = {
+            'size': stats['batch_size'],
+            'id': batch['spkr_id'],
+            'index': [ index for index in batch['index'] ],
+            'text_len': [ text.shape[0] for text in batch['text'] ],
+            'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
+            'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
+        }
+
+        del stats['batch_size']
+        del stats['wall_time']
+        del stats['global_step']
+
+        elapsed_time = stats.get("elapsed_time", 0)
+        _logger.info(f"Training Metrics: {json.dumps(stats)}.")
+
+        command = _non_blocking_input()
+
+        if "@" in command:
+            what, when = command.split("@")
+            try:
+                events.append((what, int(when)))
+                _logger.info(f"Event {command} registered.")
+            except Exception as e:
+                _logger.error(e)
+            command = ""
+
+        # Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
+        events = [e for e in events if e[1] >= engines.global_step]
+        commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
+
+        for command in commands:
+            if command in ["event show", "event"]:
+                msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
+                _logger.info(msg)
+
+            if command == "event clear":
+                events.clear()
+
+            if "time" in command:
+                target_iter = cfg.trainer.iterations
+                if " to " in command:
+                    try:
+                        target_iter = int(command.split(" to ")[-1])
+                    except Exception as e:
+                        _logger.error(e)
+                remaining_iters = target_iter - engines.global_step + 1
+                remaining_time = int(remaining_iters * elapsed_time)
+                _logger.info(humanize.precisedelta(remaining_time))
+
+            if "lr" in command:
+                rate = float(command.split(" ")[-1])
+                engines.set_lr(rate)
+                print("Updating LR to:", rate)
+
+            save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
+
+            saving_commands = ["save"]
+
+            if cfg.trainer.save_on_quit:
+                saving_commands.append("quit")
+
+            if engines.global_step != last_save_step:
+                if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
+                    engines.save_checkpoint()
+                    last_save_step = engines.global_step
+
+            if engines.global_step != last_eval_step:
+                if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
+                    do_gc()
+
+                    engines.eval()
+                    eval_fn(engines=engines)
+                    engines.train()
+                    last_eval_step = engines.global_step
+
+            if command in ["quit"]:
+                return
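For reference, the rewritten loop reads commands from the trainer's stdin (on the global leader, then broadcast to the other ranks). A sketch of a console session based on the handlers above; the iteration numbers and learning rate are made up:

    save            # request a checkpoint after the current step
    lr 1.0e-5       # calls engines.set_lr(1.0e-5)
    time to 10000   # logs a humanize'd estimate of time remaining until iteration 10000
    eval@5000       # registers an event that runs eval_fn once global_step reaches 5000
    event show      # lists pending events
    quit            # stops training, checkpointing first when cfg.trainer.save_on_quit is set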