mrq 2023-08-02 18:12:36 -05:00
parent 7a06b27a9c
commit 0f9b81de75
2 changed files with 204 additions and 201 deletions

View File

@@ -48,6 +48,7 @@ setup(
"numpy>=1.23.3",
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"pandas>=1.5.0",
"torch>=1.13.0",

View File

@@ -2,252 +2,254 @@
# https://github.com/enhuiz/pytorch-training-utilities
"""
# todo: replace this
import humanize
import json
import logging
import time
from typing import Any, Protocol
import numpy as np
import random
import selectors
import sys
import torch
import torch.distributed
from deepspeed import DeepSpeedEngine
from torch import Tensor
from torch.distributed import all_reduce
from .config import Config
from .distributed import fix_unset_envs
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from functools import cache
from torch.distributed import broadcast_object_list
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Protocol
Stats = dict[str, float]
from ..config import Config
from .distributed import (
global_leader_only,
global_rank,
is_global_leader,
is_local_leader,
local_leader_only,
)
from .engines import Engine, Engines, TrainFeeder
from .utils import to_device, do_gc
_logger = logging.getLogger(__name__)
_engines: Engines
_command: str
def get_global_step():
try:
return _engines.global_step
except Exception:
return None
def get_micro_step():
try:
return _engines.micro_step
except Exception:
return None
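# get_global_step() / get_micro_step() let other modules read the active trainer's progress;
# they return None until trainer.train() has installed the global _engines.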
class Engine(DeepSpeedEngine):
def __init__(self, *args, **kwargs):
fix_unset_envs()
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
def freeze(self):
for p in self.module.parameters():
if p.requires_grad:
p.requires_grad_(False)
self._frozen_params.add(p)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
@property
def global_step(self):
return self.global_steps
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
def dispatch_attribute(self, *args, **kwargs):
return dispatch_attribute(self.module, *args, **kwargs)
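# Engine wraps DeepSpeedEngine with freeze()/unfreeze() helpers for the underlying module's
# parameters and forwards gather_attribute()/dispatch_attribute() to that module.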
def get_cfg():
try:
return _engines.cfg
except Exception:
raise RuntimeError("Trainer has not been set up. Have you called trainer.train()?")
class TrainFeeder(Protocol):
def __call__(
self, *, engines: "Engines", batch: Any, name: str
) -> None | tuple[Tensor, Stats]:
...
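# A minimal conforming feeder might look like the sketch below; the forward call and the stat
# names are illustrative assumptions, not part of this repo:
#
#   def example_feeder(*, engines, batch, name):
#       engine = engines[name]
#       loss = engine(batch)  # hypothetical forward pass that returns a scalar loss
#       return loss, {"loss": loss.item()}
#
# Returning None instead of (loss, stats) tells Engines.step() to skip that engine for the batch.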
def get_cmd():
try:
return _command
except Exception:
raise RuntimeError("Trainer has not been set up. Have you called trainer.train()?")
class Engines(dict[str, Engine]):
def setup(self, cfg: Config):
self._cfg = cfg
self._global_step = 0
get_iteration = get_global_step
@property
def cfg(self) -> Config:
return self._cfg
@property
def config(self):
return self._cfg
class EnginesLoader(Protocol):
def __call__(self) -> Engines:
...
@property
def global_step(self):
return self._global_step
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
ret |= engine.gather_attribute(*args, **kwargs)
return ret
def load_engines(engines: dict[str, Engine], config: Config):
engines = Engines(engines)
engines.setup(config)
if not engines.cfg.trainer.load_state_dict:
engines.load_checkpoint()
return engines
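# load_engines() wraps a plain dict of engines in the Engines container, attaches the config,
# and restores their checkpoints unless cfg.trainer.load_state_dict is set.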
def dispatch_attribute(self, *args, **kwargs):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def save_checkpoint(self, tag=None):
if not tag:
tag = self.cfg.trainer.save_tag
tag = tag.lower()
if tag[:2] == "it" or tag[:4] == "step":
tag = self.global_step
class EvalFn(Protocol):
def __call__(self, *, engines: Engines):
...
self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
def load_checkpoint(self, tag=None):
if not tag:
tag = self.cfg.trainer.load_tag
class Logger(Protocol):
def __call__(self, *, data: dict):
...
for name, engine in self.items():
load_dir = self.cfg.ckpt_dir / name
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=self.cfg.trainer.strict_loading,
load_optimizer_states=self.cfg.trainer.load_states,
load_lr_scheduler_states=self.cfg.trainer.load_states,
load_module_only=False, # not self.cfg.trainer.load_states,
)
if self.cfg.trainer.restart_step_count:
engine.global_steps = 0
# re-apply the configured LR: loading a checkpoint overwrites it, but only when no scheduler is in use
if self.cfg.hyperparameters.scheduler_type == "":
self.set_lr(self.cfg.hyperparameters.learning_rate)
@cache
def _get_stdin_selector():
selector = selectors.DefaultSelector()
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
return selector
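# _get_stdin_selector() is cached so stdin is registered with a selector exactly once per process.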
self._update_global_step()
def set_lr(self, lr):
try:
for engine in self.values():
if hasattr(engine.optimizer, 'param_groups'):
_logger.debug(engine.optimizer.param_groups)
for param_group in engine.optimizer.param_groups:
param_group['lr'] = lr
else:
engine.optimizer.set_lr(lr)
except Exception as e:
_logger.error(str(e))
def _non_blocking_input():
global _command
l = [""]
if is_global_leader():
s = ""
selector = _get_stdin_selector()
events = selector.select(timeout=0)
for key, _ in events:
s: str = key.fileobj.readline().strip()
_logger.info(f'Got stdin "{s}".')
l[0] = s
broadcast_object_list(l, src=0)
_command = l[0]
return _command
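# _non_blocking_input(): only the global leader polls stdin (with a zero timeout so training never
# blocks), then the line is broadcast to every rank so all processes act on the same command.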
def _update_global_step(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
def eval(self):
for engine in self.values():
engine.eval()
def _make_infinite_epochs(dl):
while True:
_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
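# _make_infinite_epochs() re-yields the DataLoader forever, so the training loop is bounded by
# cfg.trainer.iterations rather than by epoch count.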
def train(self):
for engine in self.values():
engine.train()
def step(self, feeder: TrainFeeder, batch):
total_elapsed_time = 0
@local_leader_only(default=None)
def logger(data):
return _logger.info(json.dumps(data, default=str))
stats: Any = dict()
if self.cfg.trainer.gc_mode == 'step':
do_gc()
def seed(seed):
# Set up random seeds, after fork()
random.seed(seed + global_rank())
np.random.seed(seed + global_rank())
torch.manual_seed(seed + global_rank())
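# seed() offsets the base seed by global_rank() so every process gets a distinct but reproducible
# random stream across random, numpy, and torch.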
batch = to_device(batch, torch.cuda.current_device())
for name, engine in self.items():
torch.cuda.synchronize()
if self.cfg.trainer.gc_mode == 'substep':
do_gc()
def train(
engines_loader: EnginesLoader,
train_dl: DataLoader,
train_feeder: TrainFeeder,
eval_fn: EvalFn,
logger: Logger = logger,
):
engines = engines_loader()
cfg = engines.cfg
start_time = time.time()
"""
if is_local_leader():
cfg.dump()
_logger.info(cfg)
"""
tries = 4
n_ooms = torch.zeros([], device=self.cfg.device)
if self.cfg.trainer.aggressive_optimizations:
batch = to_device(batch, torch.cuda.current_device())
# engine = engine.to(torch.cuda.current_device())
# Setup global engines
global _engines
_engines = engines
while tries >= 0:
try:
maybe_loss_and_engine_stats = feeder( engines=self, batch=batch, name=name )
break
except RuntimeError as e:
print("Forward", str(e))
events = []
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
eval_fn = global_leader_only(eval_fn)
# shrink batch size until it's happy
for k in batch:
batch[k] = batch[k][:-1]
# Pre-loop command
command = _non_blocking_input()
if command in ["eval", "eval_quit"]:
engines.eval()
eval_fn(engines=engines)
engines.train()
if command in ["quit", "eval_quit"]:
return
if tries <= 0:
# trigger OOM
n_ooms += 1
else:
# also do GC
do_gc()
continue
last_save_step = engines.global_step
last_eval_step = 0
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!")
# Training loop
for batch in _make_infinite_epochs(train_dl):
if engines.global_step >= cfg.trainer.iterations:
break
# Here we allow skipping optimizers. It's useful when, for example,
# skipping discriminators in the beginning of GAN training.
if maybe_loss_and_engine_stats is None:
continue
loss, engine_stats = maybe_loss_and_engine_stats
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(feeder=train_feeder, batch=batch)
n_ooms = torch.zeros([], device=self.cfg.device)
if self.cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
stats['it'] = iteration
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
try:
engine.backward(loss)
except RuntimeError as e:
print("Backwards:", str(e))
stats['batch'] = {
'size': stats['batch_size'],
'id': batch['spkr_id'],
'index': [ index for index in batch['index'] ],
'text_len': [ text.shape[0] for text in batch['text'] ],
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
}
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
n_ooms += 1
del stats['batch_size']
del stats['wall_time']
del stats['global_step']
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
elapsed_time = stats.get("elapsed_time", 0)
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
engine.step()
torch.cuda.synchronize()
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
command = _non_blocking_input()
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
loss=loss.item(),
lr=engine.get_lr()[0],
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
elapsed_time=elapsed_time,
engine_step=engine.global_step,
**engine_stats,
)
}
),
)
del loss
# engine = engine.to('cpu')
if "@" in command:
what, when = command.split("@")
try:
events.append((what, int(when)))
_logger.info(f"Event {command} registered.")
except Exception as e:
_logger.error(e)
command = ""
self._update_global_step()
stats["batch_size"] = len(batch["text"])
stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time()
stats["global_step"] = self.global_step
# Commands to act on now: the command just read from stdin plus any registered events whose trigger
# iteration is the current step; events that have already passed are dropped.
events = [e for e in events if e[1] >= engines.global_step]
commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
return stats
for command in commands:
if command in ["event show", "event"]:
msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
_logger.info(msg)
if command == "event clear":
events.clear()
if "time" in command:
target_iter = cfg.trainer.iterations
if " to " in command:
try:
target_iter = int(command.split(" to ")[-1])
except Exception as e:
_logger.error(e)
remaining_iters = target_iter - engines.global_step + 1
remaining_time = int(remaining_iters * elapsed_time)
_logger.info(humanize.precisedelta(remaining_time))
if "lr" in command:
rate = float(command.split(" ")[-1])
engines.set_lr(rate)
print("Updating LR to:", rate)
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
saving_commands = ["save"]
if cfg.trainer.save_on_quit:
saving_commands.append("quit")
if engines.global_step != last_save_step:
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
engines.save_checkpoint()
last_save_step = engines.global_step
if engines.global_step != last_eval_step:
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
do_gc()
engines.eval()
eval_fn(engines=engines)
engines.train()
last_eval_step = engines.global_step
if command in ["quit"]:
return
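# Rough call shape for train(), with hypothetical names for the caller-supplied pieces:
#
#   train(
#       engines_loader=my_engines_loader,  # returns an Engines mapping
#       train_dl=my_dataloader,
#       train_feeder=example_feeder,       # see the TrainFeeder protocol above
#       eval_fn=my_eval_fn,
#   )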