oops
This commit is contained in:
parent
7a06b27a9c
commit
0f9b81de75
1
setup.py
1
setup.py
|
@ -48,6 +48,7 @@ setup(
|
||||||
"numpy>=1.23.3",
|
"numpy>=1.23.3",
|
||||||
"omegaconf==2.0.6",
|
"omegaconf==2.0.6",
|
||||||
"tqdm>=4.64.1",
|
"tqdm>=4.64.1",
|
||||||
|
"humanize>=4.4.0",
|
||||||
|
|
||||||
"pandas>=1.5.0",
|
"pandas>=1.5.0",
|
||||||
"torch>=1.13.0",
|
"torch>=1.13.0",
|
||||||
|
|
|
@ -2,252 +2,254 @@
|
||||||
# https://github.com/enhuiz/pytorch-training-utilities
|
# https://github.com/enhuiz/pytorch-training-utilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# todo: replace this
|
import humanize
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import numpy as np
|
||||||
from typing import Any, Protocol
|
import random
|
||||||
|
import selectors
|
||||||
|
import sys
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
from deepspeed import DeepSpeedEngine
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.distributed import all_reduce
|
|
||||||
|
|
||||||
from .config import Config
|
from functools import cache
|
||||||
from .distributed import fix_unset_envs
|
from torch.distributed import broadcast_object_list
|
||||||
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
Stats = dict[str, float]
|
from ..config import Config
|
||||||
|
from .distributed import (
|
||||||
|
global_leader_only,
|
||||||
|
global_rank,
|
||||||
|
is_global_leader,
|
||||||
|
is_local_leader,
|
||||||
|
local_leader_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .engines import Engine, Engines, TrainFeeder
|
||||||
|
from .utils import to_device, do_gc
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
_engines: Engines
|
||||||
|
_command: str
|
||||||
|
|
||||||
|
def get_global_step():
|
||||||
|
try:
|
||||||
|
return _engines.global_step
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_micro_step():
|
||||||
|
try:
|
||||||
|
return _engines.micro_step
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Engine(DeepSpeedEngine):
|
def get_cfg():
|
||||||
def __init__(self, *args, **kwargs):
|
try:
|
||||||
fix_unset_envs()
|
return _engines.cfg
|
||||||
super().__init__(None, *args, **kwargs)
|
except:
|
||||||
self._frozen_params = set()
|
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||||
|
|
||||||
def freeze(self):
|
|
||||||
for p in self.module.parameters():
|
|
||||||
if p.requires_grad:
|
|
||||||
p.requires_grad_(False)
|
|
||||||
self._frozen_params.add(p)
|
|
||||||
|
|
||||||
def unfreeze(self):
|
|
||||||
for p in self._frozen_params:
|
|
||||||
p.requires_grad_(True)
|
|
||||||
self._frozen_params.clear()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def global_step(self):
|
|
||||||
return self.global_steps
|
|
||||||
|
|
||||||
def gather_attribute(self, *args, **kwargs):
|
|
||||||
return gather_attribute(self.module, *args, **kwargs)
|
|
||||||
|
|
||||||
def dispatch_attribute(self, *args, **kwargs):
|
|
||||||
return dispatch_attribute(self.module, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainFeeder(Protocol):
|
def get_cmd():
|
||||||
def __call__(
|
try:
|
||||||
self, *, engines: "Engines", batch: Any, name: str
|
return _command
|
||||||
) -> None | tuple[Tensor, Stats]:
|
except:
|
||||||
...
|
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||||
|
|
||||||
|
|
||||||
class Engines(dict[str, Engine]):
|
get_iteration = get_global_step
|
||||||
def setup(self, cfg: Config):
|
|
||||||
self._cfg = cfg
|
|
||||||
self._global_step = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cfg(self) -> Config:
|
|
||||||
return self._cfg
|
|
||||||
|
|
||||||
@property
|
class EnginesLoader(Protocol):
|
||||||
def config(self):
|
def __call__(self) -> Engines:
|
||||||
return self._cfg
|
...
|
||||||
|
|
||||||
@property
|
|
||||||
def global_step(self):
|
|
||||||
return self._global_step
|
|
||||||
|
|
||||||
def gather_attribute(self, *args, **kwargs):
|
def load_engines(engines: dict[str, Engine], config: Config):
|
||||||
ret = {}
|
engines = Engines(engines)
|
||||||
for engine in self.values():
|
engines.setup(config)
|
||||||
ret |= engine.gather_attribute(*args, **kwargs)
|
if not engines.cfg.trainer.load_state_dict:
|
||||||
return ret
|
engines.load_checkpoint()
|
||||||
|
return engines
|
||||||
|
|
||||||
def dispatch_attribute(self, *args, **kwargs):
|
|
||||||
for engine in self.values():
|
|
||||||
engine.dispatch_attribute(*args, **kwargs)
|
|
||||||
|
|
||||||
def save_checkpoint(self, tag=None):
|
class EvalFn(Protocol):
|
||||||
if not tag:
|
def __call__(self, *, engines: Engines):
|
||||||
tag = self.cfg.trainer.save_tag
|
...
|
||||||
tag = tag.lower()
|
|
||||||
if tag[:2] == "it" or tag[:4] == "step":
|
|
||||||
tag = self.global_step
|
|
||||||
|
|
||||||
self.cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
for name, engine in self.items():
|
|
||||||
engine.save_checkpoint(self.cfg.ckpt_dir / name, tag=tag)
|
|
||||||
|
|
||||||
def load_checkpoint(self, tag=None):
|
class Logger(Protocol):
|
||||||
if not tag:
|
def __call__(self, *, data: dict):
|
||||||
tag = self.cfg.trainer.load_tag
|
...
|
||||||
|
|
||||||
for name, engine in self.items():
|
|
||||||
load_dir = self.cfg.ckpt_dir / name
|
|
||||||
engine.load_checkpoint(
|
|
||||||
tag=tag,
|
|
||||||
load_dir=load_dir,
|
|
||||||
load_module_strict=self.cfg.trainer.strict_loading,
|
|
||||||
load_optimizer_states=self.cfg.trainer.load_states,
|
|
||||||
load_lr_scheduler_states=self.cfg.trainer.load_states,
|
|
||||||
load_module_only=False, # not self.cfg.trainer.load_states,
|
|
||||||
)
|
|
||||||
if self.cfg.trainer.restart_step_count:
|
|
||||||
engine.global_steps = 0
|
|
||||||
|
|
||||||
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
|
@cache
|
||||||
if self.cfg.hyperparameters.scheduler_type == "":
|
def _get_stdin_selector():
|
||||||
self.set_lr(self.cfg.hyperparameters.learning_rate)
|
selector = selectors.DefaultSelector()
|
||||||
|
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
|
||||||
|
return selector
|
||||||
|
|
||||||
self._update_global_step()
|
|
||||||
|
|
||||||
def set_lr(self, lr):
|
def _non_blocking_input():
|
||||||
try:
|
global _command
|
||||||
for engine in self.values():
|
l = [""]
|
||||||
if hasattr(engine.optimizer, 'param_groups'):
|
if is_global_leader():
|
||||||
print(engine.optimizer.param_groups)
|
s = ""
|
||||||
for param_group in engine.optimizer.param_groups:
|
selector = _get_stdin_selector()
|
||||||
param_group['lr'] = lr
|
events = selector.select(timeout=0)
|
||||||
else:
|
for key, _ in events:
|
||||||
engine.optimizer.set_lr(lr)
|
s: str = key.fileobj.readline().strip()
|
||||||
except Exception as e:
|
_logger.info(f'Get stdin "{s}".')
|
||||||
print(str(e))
|
l[0] = s
|
||||||
|
broadcast_object_list(l, src=0)
|
||||||
|
_command = l[0]
|
||||||
|
return _command
|
||||||
|
|
||||||
def _update_global_step(self):
|
|
||||||
for engine in self.values():
|
|
||||||
self._global_step = max(self._global_step, engine.global_step)
|
|
||||||
|
|
||||||
def eval(self):
|
def _make_infinite_epochs(dl):
|
||||||
for engine in self.values():
|
while True:
|
||||||
engine.eval()
|
_logger.info("New epoch starts.")
|
||||||
|
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
||||||
|
|
||||||
def train(self):
|
|
||||||
for engine in self.values():
|
|
||||||
engine.train()
|
|
||||||
|
|
||||||
def step(self, feeder: TrainFeeder, batch):
|
@local_leader_only(default=None)
|
||||||
total_elapsed_time = 0
|
def logger(data):
|
||||||
|
return _logger.info(json.dumps(data, default=str))
|
||||||
|
|
||||||
stats: Any = dict()
|
|
||||||
|
|
||||||
if self.cfg.trainer.gc_mode == 'step':
|
def seed(seed):
|
||||||
do_gc()
|
# Set up random seeds, after fork()
|
||||||
|
random.seed(seed + global_rank())
|
||||||
|
np.random.seed(seed + global_rank())
|
||||||
|
torch.manual_seed(seed + global_rank())
|
||||||
|
|
||||||
batch = to_device(batch, torch.cuda.current_device())
|
|
||||||
|
|
||||||
for name, engine in self.items():
|
def train(
|
||||||
torch.cuda.synchronize()
|
engines_loader: EnginesLoader,
|
||||||
if self.cfg.trainer.gc_mode == 'substep':
|
train_dl: DataLoader,
|
||||||
do_gc()
|
train_feeder: TrainFeeder,
|
||||||
|
eval_fn: EvalFn,
|
||||||
|
logger: Logger = logger,
|
||||||
|
):
|
||||||
|
engines = engines_loader()
|
||||||
|
cfg = engines.cfg
|
||||||
|
|
||||||
start_time = time.time()
|
"""
|
||||||
|
if is_local_leader():
|
||||||
|
cfg.dump()
|
||||||
|
_logger.info(cfg)
|
||||||
|
"""
|
||||||
|
|
||||||
tries = 4
|
# Setup global engines
|
||||||
n_ooms = torch.zeros([], device=self.cfg.device)
|
global _engines
|
||||||
if self.cfg.trainer.aggressive_optimizations:
|
_engines = engines
|
||||||
batch = to_device(batch, torch.cuda.current_device())
|
|
||||||
# engine = engine.to(torch.cuda.current_device())
|
|
||||||
|
|
||||||
while tries >= 0:
|
events = []
|
||||||
try:
|
|
||||||
maybe_loss_and_engine_stats = feeder( engines=self, batch=batch, name=name )
|
|
||||||
break
|
|
||||||
except RuntimeError as e:
|
|
||||||
print("Forward", str(e))
|
|
||||||
|
|
||||||
if "out of memory" not in str(e):
|
eval_fn = global_leader_only(eval_fn)
|
||||||
self.save_checkpoint()
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# shrink batch size until it's happy
|
# Pre-loop command
|
||||||
for k in batch:
|
command = _non_blocking_input()
|
||||||
batch[k] = batch[k][:-1]
|
if command in ["eval", "eval_quit"]:
|
||||||
|
engines.eval()
|
||||||
|
eval_fn(engines=engines)
|
||||||
|
engines.train()
|
||||||
|
if command in ["quit", "eval_quit"]:
|
||||||
|
return
|
||||||
|
|
||||||
if tries <= 0:
|
last_save_step = engines.global_step
|
||||||
# trigger OOM
|
last_eval_step = 0
|
||||||
n_ooms += 1
|
|
||||||
else:
|
|
||||||
# also do GC
|
|
||||||
do_gc()
|
|
||||||
continue
|
|
||||||
|
|
||||||
all_reduce(n_ooms)
|
# Training loop
|
||||||
if n_ooms.item() > 0:
|
for batch in _make_infinite_epochs(train_dl):
|
||||||
self.save_checkpoint()
|
if engines.global_step >= cfg.trainer.iterations:
|
||||||
raise RuntimeError("Out of memory during forward pass!")
|
break
|
||||||
|
|
||||||
# Here we allow skip optimizers. It's useful when, for example,
|
#batch = to_device(batch, torch.cuda.current_device())
|
||||||
# skipping discriminators in the begining of GAN training.
|
stats = engines.step(feeder=train_feeder, batch=batch)
|
||||||
if maybe_loss_and_engine_stats is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
loss, engine_stats = maybe_loss_and_engine_stats
|
|
||||||
|
|
||||||
n_ooms = torch.zeros([], device=self.cfg.device)
|
iteration = stats['global_step'] # * cfg.hyperparameters.gradient_accumulation_steps
|
||||||
|
stats['it'] = iteration
|
||||||
if self.cfg.trainer.aggressive_optimizations:
|
stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(train_dl)
|
||||||
batch = to_device(batch, 'cpu')
|
|
||||||
|
|
||||||
try:
|
stats['batch'] = {
|
||||||
engine.backward(loss)
|
'size': stats['batch_size'],
|
||||||
except RuntimeError as e:
|
'id': batch['spkr_id'],
|
||||||
print("Backwards:", str(e))
|
'index': [ index for index in batch['index'] ],
|
||||||
|
'text_len': [ text.shape[0] for text in batch['text'] ],
|
||||||
|
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
|
||||||
|
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||||
|
}
|
||||||
|
|
||||||
if "out of memory" not in str(e):
|
del stats['batch_size']
|
||||||
self.save_checkpoint()
|
del stats['wall_time']
|
||||||
raise e
|
del stats['global_step']
|
||||||
|
|
||||||
n_ooms += 1
|
|
||||||
|
|
||||||
all_reduce(n_ooms)
|
elapsed_time = stats.get("elapsed_time", 0)
|
||||||
if n_ooms.item() > 0:
|
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||||
self.save_checkpoint()
|
|
||||||
raise RuntimeError("Out of memory during backwards pass!")
|
|
||||||
|
|
||||||
engine.step()
|
command = _non_blocking_input()
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
total_elapsed_time += elapsed_time
|
|
||||||
|
|
||||||
stats.update(
|
if "@" in command:
|
||||||
flatten_dict(
|
what, when = command.split("@")
|
||||||
{
|
try:
|
||||||
name.split("-")[0]: dict(
|
events.append((what, int(when)))
|
||||||
loss=loss.item(),
|
_logger.info(f"Event {command} registered.")
|
||||||
lr=engine.get_lr()[0],
|
except Exception as e:
|
||||||
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
|
_logger.error(e)
|
||||||
elapsed_time=elapsed_time,
|
command = ""
|
||||||
engine_step=engine.global_step,
|
|
||||||
**engine_stats,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
del loss
|
|
||||||
# engine = engine.to('cpu')
|
|
||||||
|
|
||||||
self._update_global_step()
|
# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
|
||||||
stats["batch_size"] = len(batch["text"])
|
events = [e for e in events if e[1] >= engines.global_step]
|
||||||
stats["elapsed_time"] = total_elapsed_time
|
commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
|
||||||
stats["wall_time"] = time.time()
|
|
||||||
stats["global_step"] = self.global_step
|
|
||||||
|
|
||||||
return stats
|
for command in commands:
|
||||||
|
if command in ["event show", "event"]:
|
||||||
|
msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
|
||||||
|
_logger.info(msg)
|
||||||
|
|
||||||
|
if command == "event clear":
|
||||||
|
events.clear()
|
||||||
|
|
||||||
|
if "time" in command:
|
||||||
|
target_iter = cfg.trainer.iterations
|
||||||
|
if " to " in command:
|
||||||
|
try:
|
||||||
|
target_iter = int(command.split(" to ")[-1])
|
||||||
|
except Exception as e:
|
||||||
|
_logger.error(e)
|
||||||
|
remaining_iters = target_iter - engines.global_step + 1
|
||||||
|
remaining_time = int(remaining_iters * elapsed_time)
|
||||||
|
_logger.info(humanize.precisedelta(remaining_time))
|
||||||
|
|
||||||
|
if "lr" in command:
|
||||||
|
rate = float(command.split(" ")[-1])
|
||||||
|
engines.set_lr(rate)
|
||||||
|
print("Updating LR to:", rate)
|
||||||
|
|
||||||
|
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
|
||||||
|
|
||||||
|
saving_commands = ["save"]
|
||||||
|
|
||||||
|
if cfg.trainer.save_on_quit:
|
||||||
|
saving_commands.append("quit")
|
||||||
|
|
||||||
|
if engines.global_step != last_save_step:
|
||||||
|
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
|
||||||
|
engines.save_checkpoint()
|
||||||
|
last_save_step = engines.global_step
|
||||||
|
|
||||||
|
if engines.global_step != last_eval_step:
|
||||||
|
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
|
||||||
|
do_gc()
|
||||||
|
|
||||||
|
engines.eval()
|
||||||
|
eval_fn(engines=engines)
|
||||||
|
engines.train()
|
||||||
|
last_eval_step = engines.global_step
|
||||||
|
|
||||||
|
if command in ["quit"]:
|
||||||
|
return
|
Loading…
Reference in New Issue
Block a user