oops
This commit is contained in:
parent
7a06b27a9c
commit
0f9b81de75
1
setup.py
1
setup.py
|
@ -48,6 +48,7 @@ setup(
|
|||
"numpy>=1.23.3",
|
||||
"omegaconf==2.0.6",
|
||||
"tqdm>=4.64.1",
|
||||
"humanize>=4.4.0",
|
||||
|
||||
"pandas>=1.5.0",
|
||||
"torch>=1.13.0",
|
||||
|
|
|
@ -2,252 +2,254 @@
|
|||
# https://github.com/enhuiz/pytorch-training-utilities
|
||||
"""
|
||||
|
||||
# todo: replace this
|
||||
|
||||
import humanize
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Protocol
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
import selectors
|
||||
import sys
|
||||
import torch
|
||||
import torch.distributed
|
||||
from deepspeed import DeepSpeedEngine
|
||||
from torch import Tensor
|
||||
from torch.distributed import all_reduce
|
||||
|
||||
from .config import Config
|
||||
from .distributed import fix_unset_envs
|
||||
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
||||
from functools import cache
|
||||
from torch.distributed import broadcast_object_list
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from typing import Protocol
|
||||
|
||||
Stats = dict[str, float]
|
||||
from ..config import Config
|
||||
from .distributed import (
|
||||
global_leader_only,
|
||||
global_rank,
|
||||
is_global_leader,
|
||||
is_local_leader,
|
||||
local_leader_only,
|
||||
)
|
||||
|
||||
from .engines import Engine, Engines, TrainFeeder
|
||||
from .utils import to_device, do_gc
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_engines: Engines
|
||||
_command: str
|
||||
|
||||
def get_global_step():
|
||||
try:
|
||||
return _engines.global_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_micro_step():
|
||||
try:
|
||||
return _engines.micro_step
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
class Engine(DeepSpeedEngine):
|
||||
def __init__(self, *args, **kwargs):
|
||||
fix_unset_envs()
|
||||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
def freeze(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
p.requires_grad_(False)
|
||||
self._frozen_params.add(p)
|
||||
|
||||
def unfreeze(self):
|
||||
for p in self._frozen_params:
|
||||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
return dispatch_attribute(self.module, *args, **kwargs)
|
||||
def get_cfg():
|
||||
try:
|
||||
return _engines.cfg
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
|
||||
|
||||
class TrainFeeder(Protocol):
|
||||
def __call__(
|
||||
self, *, engines: "Engines", batch: Any, name: str
|
||||
) -> None | tuple[Tensor, Stats]:
|
||||
def get_cmd():
|
||||
try:
|
||||
return _command
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
|
||||
|
||||
get_iteration = get_global_step
|
||||
|
||||
|
||||
class EnginesLoader(Protocol):
|
||||
def __call__(self) -> Engines:
|
||||
...
|
||||
|
||||
|
||||
class Engines(dict[str, Engine]):
|
||||
def setup(self, cfg: Config):
|
||||
self._cfg = cfg
|
||||
self._global_step = 0
|
||||
def load_engines(engines: dict[str, Engine], config: Config):
|
||||
engines = Engines(engines)
|
||||
engines.setup(config)
|
||||
if not engines.cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
return engines
|
||||
|
||||
@property
|
||||
def cfg(self) -> Config:
|
||||
return self._cfg
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._cfg
|
||||
class EvalFn(Protocol):
|
||||
def __call__(self, *, engines: Engines):
|
||||
...
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self._global_step
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
ret = {}
|
||||
for engine in self.values():
|
||||
ret |= engine.gather_attribute(*args, **kwargs)
|
||||
return ret
|
||||
class Logger(Protocol):
|
||||
def __call__(self, *, data: dict):
|
||||
...
|
||||
|
||||
def dispatch_attribute(self, *args, **kwargs):
|
||||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def save_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = self.cfg.trainer.save_tag
|
||||
tag = tag.lower()
|
||||
if tag[:2] == "it" or tag[:4] == "step":
|
||||
tag = self.global_step
|
||||
@cache
|
||||
def _get_stdin_selector():
|
||||
selector = selectors.DefaultSelector()
|
||||
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
|
||||
return selector
|
||||
|
||||
self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name, engine in self.items():
|
||||
engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
|
||||
|
||||
def load_checkpoint(self, tag=None):
|
||||
if not tag:
|
||||
tag = self.cfg.trainer.load_tag
|
||||
def _non_blocking_input():
|
||||
global _command
|
||||
l = [""]
|
||||
if is_global_leader():
|
||||
s = ""
|
||||
selector = _get_stdin_selector()
|
||||
events = selector.select(timeout=0)
|
||||
for key, _ in events:
|
||||
s: str = key.fileobj.readline().strip()
|
||||
_logger.info(f'Get stdin "{s}".')
|
||||
l[0] = s
|
||||
broadcast_object_list(l, src=0)
|
||||
_command = l[0]
|
||||
return _command
|
||||
|
||||
for name, engine in self.items():
|
||||
load_dir = self.cfg.ckpt_dir / name
|
||||
engine.load_checkpoint(
|
||||
tag=tag,
|
||||
load_dir=load_dir,
|
||||
load_module_strict=self.cfg.trainer.strict_loading,
|
||||
load_optimizer_states=self.cfg.trainer.load_states,
|
||||
load_lr_scheduler_states=self.cfg.trainer.load_states,
|
||||
load_module_only=False, # not self.cfg.trainer.load_states,
|
||||
)
|
||||
if self.cfg.trainer.restart_step_count:
|
||||
engine.global_steps = 0
|
||||
|
||||
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
|
||||
if self.cfg.hyperparameters.scheduler_type == "":
|
||||
self.set_lr(self.cfg.hyperparameters.learning_rate)
|
||||
def _make_infinite_epochs(dl):
|
||||
while True:
|
||||
_logger.info("New epoch starts.")
|
||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
||||
|
||||
self._update_global_step()
|
||||
|
||||
def set_lr(self, lr):
|
||||
try:
|
||||
for engine in self.values():
|
||||
if hasattr(engine.optimizer, 'param_groups'):
|
||||
print(engine.optimizer.param_groups)
|
||||
for param_group in engine.optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
else:
|
||||
engine.optimizer.set_lr(lr)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
@local_leader_only(default=None)
|
||||
def logger(data):
|
||||
return _logger.info(json.dumps(data, default=str))
|
||||
|
||||
def _update_global_step(self):
|
||||
for engine in self.values():
|
||||
self._global_step = max(self._global_step, engine.global_step)
|
||||
|
||||
def eval(self):
|
||||
for engine in self.values():
|
||||
engine.eval()
|
||||
def seed(seed):
|
||||
# Set up random seeds, after fork()
|
||||
random.seed(seed + global_rank())
|
||||
np.random.seed(seed + global_rank())
|
||||
torch.manual_seed(seed + global_rank())
|
||||
|
||||
def train(self):
|
||||
for engine in self.values():
|
||||
engine.train()
|
||||
|
||||
def step(self, feeder: TrainFeeder, batch):
|
||||
total_elapsed_time = 0
|
||||
def train(
|
||||
engines_loader: EnginesLoader,
|
||||
train_dl: DataLoader,
|
||||
train_feeder: TrainFeeder,
|
||||
eval_fn: EvalFn,
|
||||
logger: Logger = logger,
|
||||
):
|
||||
engines = engines_loader()
|
||||
cfg = engines.cfg
|
||||
|
||||
stats: Any = dict()
|
||||
"""
|
||||
if is_local_leader():
|
||||
cfg.dump()
|
||||
_logger.info(cfg)
|
||||
"""
|
||||
|
||||
if self.cfg.trainer.gc_mode == 'step':
|
||||
do_gc()
|
||||
# Setup global engines
|
||||
global _engines
|
||||
_engines = engines
|
||||
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
events = []
|
||||
|
||||
for name, engine in self.items():
|
||||
torch.cuda.synchronize()
|
||||
if self.cfg.trainer.gc_mode == 'substep':
|
||||
do_gc()
|
||||
eval_fn = global_leader_only(eval_fn)
|
||||
|
||||
start_time = time.time()
|
||||
# Pre-loop command
|
||||
command = _non_blocking_input()
|
||||
if command in ["eval", "eval_quit"]:
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
if command in ["quit", "eval_quit"]:
|
||||
return
|
||||
|
||||
tries = 4
|
||||
n_ooms = torch.zeros([], device=self.cfg.device)
|
||||
if self.cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
# engine = engine.to(torch.cuda.current_device())
|
||||
last_save_step = engines.global_step
|
||||
last_eval_step = 0
|
||||
|
||||
while tries >= 0:
|
||||
try:
|
||||
maybe_loss_and_engine_stats = feeder( engines=self, batch=batch, name=name )
|
||||
# Training loop
|
||||
for batch in _make_infinite_epochs(train_dl):
|
||||
if engines.global_step >= cfg.trainer.iterations:
|
||||
break
|
||||
except RuntimeError as e:
|
||||
print("Forward", str(e))
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
raise e
|
||||
#batch = to_device(batch, torch.cuda.current_device())
|
||||
stats = engines.step(feeder=train_feeder, batch=batch)
|
||||
|
||||
# shrink batch size until it's happy
|
||||
for k in batch:
|
||||
batch[k] = batch[k][:-1]
|
||||
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
|
||||
stats['it'] = iteration
|
||||
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||
|
||||
if tries <= 0:
|
||||
# trigger OOM
|
||||
n_ooms += 1
|
||||
else:
|
||||
# also do GC
|
||||
do_gc()
|
||||
continue
|
||||
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during forward pass!")
|
||||
|
||||
# Here we allow skip optimizers. It's useful when, for example,
|
||||
# skipping discriminators in the begining of GAN training.
|
||||
if maybe_loss_and_engine_stats is None:
|
||||
continue
|
||||
|
||||
loss, engine_stats = maybe_loss_and_engine_stats
|
||||
|
||||
n_ooms = torch.zeros([], device=self.cfg.device)
|
||||
|
||||
if self.cfg.trainer.aggressive_optimizations:
|
||||
batch = to_device(batch, 'cpu')
|
||||
|
||||
try:
|
||||
engine.backward(loss)
|
||||
except RuntimeError as e:
|
||||
print("Backwards:", str(e))
|
||||
|
||||
if "out of memory" not in str(e):
|
||||
self.save_checkpoint()
|
||||
raise e
|
||||
|
||||
n_ooms += 1
|
||||
|
||||
all_reduce(n_ooms)
|
||||
if n_ooms.item() > 0:
|
||||
self.save_checkpoint()
|
||||
raise RuntimeError("Out of memory during backwards pass!")
|
||||
|
||||
engine.step()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_time = time.time() - start_time
|
||||
total_elapsed_time += elapsed_time
|
||||
|
||||
stats.update(
|
||||
flatten_dict(
|
||||
{
|
||||
name.split("-")[0]: dict(
|
||||
loss=loss.item(),
|
||||
lr=engine.get_lr()[0],
|
||||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
||||
elapsed_time=elapsed_time,
|
||||
engine_step=engine.global_step,
|
||||
**engine_stats,
|
||||
)
|
||||
stats['batch'] = {
|
||||
'size': stats['batch_size'],
|
||||
'id': batch['spkr_id'],
|
||||
'index': [ index for index in batch['index'] ],
|
||||
'text_len': [ text.shape[0] for text in batch['text'] ],
|
||||
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
|
||||
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||
}
|
||||
),
|
||||
)
|
||||
del loss
|
||||
# engine = engine.to('cpu')
|
||||
|
||||
self._update_global_step()
|
||||
stats["batch_size"] = len(batch["text"])
|
||||
stats["elapsed_time"] = total_elapsed_time
|
||||
stats["wall_time"] = time.time()
|
||||
stats["global_step"] = self.global_step
|
||||
del stats['batch_size']
|
||||
del stats['wall_time']
|
||||
del stats['global_step']
|
||||
|
||||
return stats
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||
|
||||
command = _non_blocking_input()
|
||||
|
||||
if "@" in command:
|
||||
what, when = command.split("@")
|
||||
try:
|
||||
events.append((what, int(when)))
|
||||
_logger.info(f"Event {command} registered.")
|
||||
except Exception as e:
|
||||
_logger.error(e)
|
||||
command = ""
|
||||
|
||||
# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
|
||||
events = [e for e in events if e[1] >= engines.global_step]
|
||||
commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
|
||||
|
||||
for command in commands:
|
||||
if command in ["event show", "event"]:
|
||||
msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
|
||||
_logger.info(msg)
|
||||
|
||||
if command == "event clear":
|
||||
events.clear()
|
||||
|
||||
if "time" in command:
|
||||
target_iter = cfg.trainer.iterations
|
||||
if " to " in command:
|
||||
try:
|
||||
target_iter = int(command.split(" to ")[-1])
|
||||
except Exception as e:
|
||||
_logger.error(e)
|
||||
remaining_iters = target_iter - engines.global_step + 1
|
||||
remaining_time = int(remaining_iters * elapsed_time)
|
||||
_logger.info(humanize.precisedelta(remaining_time))
|
||||
|
||||
if "lr" in command:
|
||||
rate = float(command.split(" ")[-1])
|
||||
engines.set_lr(rate)
|
||||
print("Updating LR to:", rate)
|
||||
|
||||
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
|
||||
|
||||
saving_commands = ["save"]
|
||||
|
||||
if cfg.trainer.save_on_quit:
|
||||
saving_commands.append("quit")
|
||||
|
||||
if engines.global_step != last_save_step:
|
||||
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||
engines.save_checkpoint()
|
||||
last_save_step = engines.global_step
|
||||
|
||||
if engines.global_step != last_eval_step:
|
||||
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
|
||||
do_gc()
|
||||
|
||||
engines.eval()
|
||||
eval_fn(engines=engines)
|
||||
engines.train()
|
||||
last_eval_step = engines.global_step
|
||||
|
||||
if command in ["quit"]:
|
||||
return
|
Loading…
Reference in New Issue
Block a user