2023-08-02 21:53:35 +00:00
|
|
|
"""
|
|
|
|
# https://github.com/enhuiz/pytorch-training-utilities
|
|
|
|
"""
|
|
|
|
|
2023-08-02 23:12:36 +00:00
|
|
|
import humanize
|
|
|
|
import json
|
2023-08-02 21:53:35 +00:00
|
|
|
import logging
|
2023-08-02 23:12:36 +00:00
|
|
|
import numpy as np
|
|
|
|
import random
|
|
|
|
import selectors
|
|
|
|
import sys
|
2023-08-02 21:53:35 +00:00
|
|
|
import torch
|
2023-08-05 20:25:41 +00:00
|
|
|
import os
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-08-02 23:12:36 +00:00
|
|
|
from functools import cache
|
|
|
|
from torch.distributed import broadcast_object_list
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from tqdm import tqdm
|
|
|
|
from typing import Protocol
|
2023-08-02 21:53:35 +00:00
|
|
|
|
2023-08-04 01:26:36 +00:00
|
|
|
from ..config import cfg
|
2023-08-24 22:05:56 +00:00
|
|
|
from .distributed import init_distributed, distributed_initialized, world_size
|
2023-08-02 23:12:36 +00:00
|
|
|
from .distributed import (
|
2023-08-04 01:26:36 +00:00
|
|
|
global_leader_only,
|
|
|
|
global_rank,
|
|
|
|
is_global_leader,
|
|
|
|
is_local_leader,
|
|
|
|
local_leader_only,
|
2023-08-02 23:12:36 +00:00
|
|
|
)
|
|
|
|
|
2023-08-27 17:26:12 +00:00
|
|
|
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
|
2023-08-04 01:26:36 +00:00
|
|
|
from ..models import get_models
|
|
|
|
|
2023-08-02 23:12:36 +00:00
|
|
|
from .utils import to_device, do_gc
|
2023-08-04 01:26:36 +00:00
|
|
|
from ..utils import wrapper as ml
|
2023-08-20 18:39:58 +00:00
|
|
|
from ..data import get_phone_symmap # should decouple from this trainer script
|
2023-08-02 21:53:35 +00:00
|
|
|
|
|
|
|
# Module-level logger used for training metrics, stdin commands and epoch notices.
_logger = logging.getLogger(__name__)
|
2023-08-02 23:12:36 +00:00
|
|
|
# Last operator command returned by _non_blocking_input() (after the rank-0 broadcast).
# Annotation only: the name is first bound when _non_blocking_input() runs.
_command: str
|
|
|
|
|
2023-09-22 18:04:17 +00:00
|
|
|
def load_engines():
    """Build and initialize an ``Engines`` collection for every configured model.

    For each model returned by ``get_models(cfg.models.get())`` this:
      * creates an optimizer (local backend, or DeepSpeed with ``torch_optimizer``)
        from ``cfg.hyperparameters``, excluding parameters listed in
        ``model._cfg.frozen_params``;
      * drops the optimizer / LR scheduler again if the model is not being trained;
      * loads weights from ``<cfg.ckpt_dir>/<name>/fp32.pth`` when requested
        (or always, for non-training models), extracting saved ``stats`` if present;
      * wraps the model in ``Engine`` (training) or ``_Engine`` (inference only).

    The collection is then set up, restored from its own checkpoint when
    ``cfg.trainer.load_state_dict`` is false, and requested params are frozen.

    Returns:
        Engines: the initialized engine collection.

    Raises:
        ValueError: if ``cfg.hyperparameters.optimizer`` is not a supported optimizer.
    """
    models = get_models(cfg.models.get())
    engines = dict()

    for name, model in models.items():
        optimizer = None
        lr_scheduler = None

        if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
            optimizer_name = cfg.hyperparameters.optimizer.lower()
            params = {
                "lr": cfg.hyperparameters.learning_rate,
            }

            if optimizer_name == "adamw":
                params["betas"] = (0.9, 0.96)
                params["eps"] = 1e-07
                params["weight_decay"] = 0.01
                optimizer_class = ml.AdamW
            elif optimizer_name == "sgd":
                # previously assigned to `optimizer`, leaving `optimizer_class` as None
                optimizer_class = ml.SGD
            elif optimizer_name == "prodigy":
                optimizer_class = ml.Prodigy
            else:
                raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

            # user-supplied optimizer params override the defaults above
            params.update(cfg.hyperparameters.optimizer_params or {})

            frozen_params = model._cfg.frozen_params
            optimizer = optimizer_class(
                [ param for param_name, param in model.named_parameters() if param_name not in frozen_params ],
                **params,
            )

        # set up our LR scheduler here

        if not model._cfg.training:
            optimizer = None
            lr_scheduler = None

        model_ckpt_dir = cfg.ckpt_dir / name
        weights_path = model_ckpt_dir / "fp32.pth"

        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
        # (only when the weights actually exist; otherwise a fresh run would crash in torch.load)
        if (
            not cfg.trainer.load_state_dict
            and cfg.trainer.backend == "deepspeed"
            and not (model_ckpt_dir / "latest").exists()
            and weights_path.exists()
        ):
            print("DeepSpeed checkpoint missing, but weights found.")
            cfg.trainer.load_state_dict = True

        stats = None
        if cfg.trainer.load_state_dict or not model._cfg.training:
            state = torch.load(weights_path, map_location=torch.device(cfg.device))

            # state dict is not just the module, extract the extra trainer details
            if "stats" in state:
                stats = state["stats"]

            if "module" in state:
                state = state["module"]

            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        # use base engine because DeepSpeed memory leaks if it's a non-training model
        engine_class = Engine if model._cfg.training else _Engine
        engines[name] = engine_class(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            _cfg=model._cfg,
            stats=stats,
        )

    engines = Engines(engines)
    engines.setup()

    if not cfg.trainer.load_state_dict:
        engines.load_checkpoint()

    # freeze requested params
    for _, engine in engines.items():
        engine.freeze(freeze_all=False)

    do_gc()

    return engines
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
class EvalFn(Protocol):
    """Structural type for the evaluation callback passed to ``train``."""

    def __call__(self, *, engines: Engines):
        """Evaluate ``engines``, which are always passed by keyword."""
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Logger(Protocol):
    """Structural type for a metrics sink invoked as ``fn(data=...)``."""

    def __call__(self, *, data: dict):
        """Consume one record of ``data`` (keyword-only)."""
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
@cache
def _get_stdin_selector():
    """Return a selector watching ``sys.stdin`` for readability.

    ``@cache`` makes this build and register the selector only once; later calls
    reuse the same object.
    """
    stdin_selector = selectors.DefaultSelector()
    stdin_selector.register(sys.stdin, selectors.EVENT_READ)
    return stdin_selector
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
2023-08-05 03:22:15 +00:00
|
|
|
# On Windows, _non_blocking_input() reads stdin keystrokes through msvcrt
# instead of the selector-based path used elsewhere.
if os.name == "nt":
    import msvcrt

    # characters typed so far for the current, not-yet-submitted command line
    _buffer = []
|
|
|
|
|
2023-08-02 23:12:36 +00:00
|
|
|
def _non_blocking_input():
    """Poll stdin for an operator command without blocking the training loop.

    Only the global leader reads the terminal. On Windows, keystrokes are
    accumulated via ``msvcrt`` until Enter is pressed. Elsewhere, a full line is
    read once the stdin selector reports it readable. With more than one process,
    the leader's result is broadcast from rank 0 so every rank sees the same command.

    Returns:
        str: the stripped command entered since the last call, or "" if none is
        complete yet. The value is also stored in the module-level ``_command``.
    """
    global _command

    # one-element list so it can be passed to broadcast_object_list
    shared = [""]

    def _windows():
        # Drain all pending keystrokes; a command is complete once Enter ('\r') arrives.
        global _buffer

        while msvcrt.kbhit():
            raw = msvcrt.getch()
            if raw in (b"\x00", b"\xe0"):
                # function/arrow keys arrive as a prefix byte plus a scan code:
                # discard both instead of crashing on a non-UTF-8 byte
                msvcrt.getch()
                continue

            ch = raw.decode('utf-8', errors='ignore')
            if ch == '\r':
                line = "".join(_buffer).strip()
                _buffer = []
                return line

            if ch == '\b':
                # honour backspace so typos can be corrected before Enter
                if _buffer:
                    _buffer.pop()
                continue

            _buffer.append(ch)

        return ""

    def _linux():
        s = ""
        selector = _get_stdin_selector()
        events = selector.select(timeout=0)
        for key, _ in events:
            s = key.fileobj.readline().strip()
        return s

    if is_global_leader():
        s = _windows() if os.name == 'nt' else _linux()

        if s != "":
            _logger.info(f'Get stdin "{s}".')

        shared[0] = s

    # every rank must take part in the broadcast, not just the leader
    if world_size() > 1:
        broadcast_object_list(shared, src=0)

    _command = shared[0]

    return _command
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _make_infinite_epochs(dl):
    """Yield items from ``dl`` forever, restarting it at the end of every pass.

    Each pass is logged as a new epoch and wrapped in a tqdm progress bar.
    """
    while True:
        _logger.info("New epoch starts.")
        progress = tqdm(dl, "Epoch progress", dynamic_ncols=True)
        for item in progress:
            yield item
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
@local_leader_only(default=None)
def logger(data):
    """Log ``data`` as a JSON string on local-leader ranks.

    Values JSON cannot encode are stringified via ``default=str``.
    """
    payload = json.dumps(data, default=str)
    return _logger.info(payload)
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
def seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs for this process.

    The base ``seed`` is offset by the global rank so each rank draws a
    different stream. Call it after fork().
    """
    rank_seed = seed + global_rank()
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(rank_seed)
|
2023-08-02 23:12:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _noop_eval_fn(*, engines=None):
    """Default ``EvalFn``: accept the ``engines`` keyword and skip evaluation."""
    return None


def train(
    train_dl: DataLoader,
    train_feeder: TrainFeeder = default_feeder,
    eval_fn: EvalFn = _noop_eval_fn,
    logger: Logger = logger,
):
    """Run the training loop until ``cfg.trainer.iterations`` steps are reached.

    Engines come from ``load_engines()``. Each step's stats are logged as JSON,
    and operator commands are polled from stdin via ``_non_blocking_input()``:

      * ``eval`` / ``eval_quit`` / ``quit`` - evaluate and/or stop
        (``eval_quit`` is only honoured before the first step);
      * ``save`` / ``export`` - save a checkpoint (export the model as well);
      * ``lr <rate>`` - change the learning rate;
      * ``time [to <step>]`` - estimate remaining time;
      * ``<command>@<step>`` - schedule ``<command>`` for ``<step>``;
        ``event`` / ``event show`` list scheduled events, ``event clear`` drops them.

    Checkpoints are also saved every ``cfg.trainer.save_frequency`` (falling back
    to ``cfg.evaluation.frequency``) steps. Evaluation runs every
    ``cfg.evaluation.frequency`` steps.

    Args:
        train_dl: training DataLoader; its dataset must expose ``paths`` and
            ``save_state_dict``.
        train_feeder: passed through to ``engines.step``.
        eval_fn: evaluation callback, called as ``eval_fn(engines=...)`` on the
            global leader only.
        logger: NOTE(review): currently unused; metrics go to ``_logger``.
    """
    engines = load_engines()

    # NOTE: config dumping is currently disabled. It used to be:
    #   if is_local_leader():
    #       cfg.dump()
    #       _logger.info(cfg)

    # Setup global engines
    global _engines
    _engines = engines

    events = []

    eval_fn = global_leader_only(eval_fn)

    # Pre-loop command
    command = _non_blocking_input()
    if command in ["eval", "eval_quit"]:
        engines.eval()
        eval_fn(engines=engines)
        engines.train()
    if command in ["quit", "eval_quit"]:
        return

    last_save_step = engines.global_step
    last_eval_step = 0

    # save/export triggers depend only on the config, so build them once
    save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency

    saving_commands = ["save"]
    export_commands = ["export"]

    if cfg.trainer.save_on_quit:
        saving_commands.append("quit")

    if cfg.trainer.export_on_quit:
        export_commands.append("quit")

    if cfg.trainer.export_on_save:
        export_commands.append("save")

    # Training loop
    for batch in _make_infinite_epochs(train_dl):
        if engines.global_step >= cfg.trainer.iterations:
            break

        #batch = to_device(batch, torch.cuda.current_device())
        stats = engines.step(batch=batch, feeder=train_feeder)

        stats['it'] = stats['global_step']
        stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)

        stats['batch'] = {
            'size': len(batch['text']),
            'id': batch['spkr_id'],
            'index': list(batch['index']),
            'text_len': [ text.shape[0] for text in batch['text'] ],
            'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
            'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
        }

        del stats['global_step']

        elapsed_time = stats.get("elapsed_time", 0)
        # default=str keeps a non-JSON value from aborting the run just to log it
        _logger.info(f"Training Metrics: {json.dumps(stats, default=str)}.")

        command = _non_blocking_input()

        # "<command>@<step>" schedules <command> for when global_step reaches <step>
        if "@" in command:
            try:
                what, when = command.split("@")
                events.append((what, int(when)))
                _logger.info(f"Event {command} registered.")
            except ValueError as e:
                # malformed event (several '@' or a non-integer step): report, keep training
                _logger.error(e)

            command = ""

        # Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
        events = [e for e in events if e[1] >= engines.global_step]
        commands = [command] + [e[0] for e in events if e[1] == engines.global_step]

        for command in commands:
            if command in ["event show", "event"]:
                msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
                _logger.info(msg)

            if command == "event clear":
                events.clear()

            if "time" in command:
                target_iter = cfg.trainer.iterations
                if " to " in command:
                    try:
                        target_iter = int(command.split(" to ")[-1])
                    except Exception as e:
                        _logger.error(e)
                remaining_iters = target_iter - engines.global_step + 1
                remaining_time = int(remaining_iters * elapsed_time)
                _logger.info(humanize.precisedelta(remaining_time))

            if "lr" in command:
                # parse inside the try: a malformed rate must not crash training
                rate = command.split(" ")[-1]
                try:
                    rate = float(rate)
                    engines.set_lr(rate)
                    print("Updating LR to:", rate)
                except Exception as e:
                    print("Failed to set LR rate to:", rate, str(e))

            if "export" in command:
                train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
                engines.save_checkpoint()
                last_save_step = engines.global_step

                if is_global_leader():
                    engines.export(userdata={"symmap": get_phone_symmap()})

            if engines.global_step != last_save_step:
                if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
                    train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
                    engines.save_checkpoint()
                    last_save_step = engines.global_step

                    if command in export_commands and is_global_leader():
                        engines.export(userdata={"symmap": get_phone_symmap()})

            if engines.global_step != last_eval_step:
                if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
                    do_gc()

                    engines.eval()
                    eval_fn(engines=engines)
                    engines.train()
                    last_eval_step = engines.global_step

            if command in ["quit"]:
                return
|