vall-e/vall_e/utils/trainer.py

355 lines
10 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
"""
# https://github.com/enhuiz/pytorch-training-utilities
"""
2023-08-02 23:12:36 +00:00
import humanize
import json
2023-08-02 21:53:35 +00:00
import logging
2023-08-02 23:12:36 +00:00
import numpy as np
import random
import selectors
import sys
2023-08-02 21:53:35 +00:00
import torch
import os
2023-08-02 21:53:35 +00:00
2023-08-02 23:12:36 +00:00
from functools import cache
from torch.distributed import broadcast_object_list
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Protocol
2023-08-02 21:53:35 +00:00
2023-08-04 01:26:36 +00:00
from ..config import cfg
2023-08-24 22:05:56 +00:00
from .distributed import init_distributed, distributed_initialized, world_size
2023-08-02 23:12:36 +00:00
from .distributed import (
2023-08-04 01:26:36 +00:00
global_leader_only,
global_rank,
is_global_leader,
is_local_leader,
local_leader_only,
2023-08-02 23:12:36 +00:00
)
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
2023-08-04 01:26:36 +00:00
from ..models import get_models
2023-08-02 23:12:36 +00:00
from .utils import to_device, do_gc
2023-08-04 01:26:36 +00:00
from ..utils import wrapper as ml
from ..data import get_phone_symmap # should decouple from this trainer script
2023-08-02 21:53:35 +00:00
_logger = logging.getLogger(__name__)
2023-08-02 23:12:36 +00:00
_command: str
def load_engines(invert=False):
2023-08-04 01:26:36 +00:00
models = get_models(cfg.models.get())
engines = dict()
2023-08-02 23:12:36 +00:00
for name, model in models.items():
# load only the models for training initially
# loads disabled models at evaluation time (to load updated weights if training separately)
# I'm sure there's a more elegant solution to this
if cfg.evaluation.load_disabled_engines:
if not invert and not model._cfg.training:
continue
if invert and model._cfg.training:
continue
# load only the models for training initially
# if load_disabled_engines, then models not marked for training will be loaded but ignored
# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
# I recommend not using this pathway
elif not cfg.trainer.load_disabled_engines:
if model._cfg.training:
continue
2023-08-02 23:12:36 +00:00
2023-08-04 01:26:36 +00:00
optimizer = None
lr_scheduler = None
2023-08-02 23:12:36 +00:00
# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
# and then have ds_cfg pass in the config flag to use pytorch adamw
# I genuinely cannot validate if this ever actually gets used in DeepSpeed
2023-08-24 22:05:56 +00:00
if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
2023-08-04 01:26:36 +00:00
optimizer = ml.AdamW(
model.parameters(),
lr=cfg.hyperparameters.learning_rate,
betas=(0.9, 0.96),
eps=1e-07,
weight_decay=0.01,
)
2023-08-02 23:12:36 +00:00
if not model._cfg.training:
optimizer = None
lr_scheduler = None
if cfg.trainer.load_state_dict or not model._cfg.training:
2023-08-04 01:26:36 +00:00
load_path = cfg.ckpt_dir / name / "fp32.pth"
state = torch.load(load_path)
2023-08-19 01:58:07 +00:00
# exporting the model from the zero_to_fp32.py exports the actual module's dict
# exporting with vall_e.export exports the state dict under .module
if "module" in state:
state = state["module"]
2023-08-19 01:58:07 +00:00
# should decouple the following from this trainer script
# probably with passing a fun that defaults to a lambda x: x deal
2023-08-19 01:58:07 +00:00
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
2023-08-19 01:58:07 +00:00
# copy weights from the dict into the old portion
model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
# copy the full tensors back
2023-08-19 01:58:07 +00:00
state['proms_emb.weight'] = model.proms_emb.weight
# extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
# copy weights from the dict into the old portion
model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
# copy the full tensors back
state['resps_emb.weight'] = model.resps_emb.weight
2023-08-19 01:58:07 +00:00
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
2023-08-02 23:12:36 +00:00
# use base engine because DeepSpeed memory leaks
engines[name] = (Engine if model._cfg.training else _Engine)(
#engines[name] = Engine(
2023-08-04 01:26:36 +00:00
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
_cfg=model._cfg,
2023-08-04 01:26:36 +00:00
)
2023-08-02 23:12:36 +00:00
2023-08-04 01:26:36 +00:00
engines = Engines(engines)
engines.setup()
2023-08-02 23:12:36 +00:00
2023-08-04 01:26:36 +00:00
if not cfg.trainer.load_state_dict:
engines.load_checkpoint()
2023-08-02 23:12:36 +00:00
do_gc()
2023-08-04 01:26:36 +00:00
return engines
2023-08-02 23:12:36 +00:00
class EvalFn(Protocol):
2023-08-04 01:26:36 +00:00
def __call__(self, *, engines: Engines):
...
2023-08-02 23:12:36 +00:00
class Logger(Protocol):
2023-08-04 01:26:36 +00:00
def __call__(self, *, data: dict):
...
2023-08-02 23:12:36 +00:00
@cache
def _get_stdin_selector():
2023-08-04 01:26:36 +00:00
selector = selectors.DefaultSelector()
selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
return selector
2023-08-02 23:12:36 +00:00
2023-08-05 03:22:15 +00:00
if os.name == "nt":
import msvcrt
_buffer = []
2023-08-02 23:12:36 +00:00
def _non_blocking_input():
2023-08-04 01:26:36 +00:00
global _command
2023-08-05 03:22:15 +00:00
global _buffer
2023-08-04 01:26:36 +00:00
l = [""]
2023-08-05 03:22:15 +00:00
def _windows():
global _buffer
if msvcrt.kbhit():
s: str = msvcrt.getch().decode('utf-8')
if s == '\r':
s = "".join(_buffer)
_buffer = []
return s
_buffer.append(s)
return ""
def _linux():
2023-08-04 01:26:36 +00:00
s = ""
selector = _get_stdin_selector()
events = selector.select(timeout=0)
for key, _ in events:
s: str = key.fileobj.readline().strip()
2023-08-05 03:22:15 +00:00
return s
if is_global_leader():
s = _windows() if os.name == 'nt' else _linux()
if s != "":
2023-08-04 01:26:36 +00:00
_logger.info(f'Get stdin "{s}".')
2023-08-05 03:22:15 +00:00
2023-08-04 01:26:36 +00:00
l[0] = s
2023-08-05 03:22:15 +00:00
2023-08-24 22:05:56 +00:00
if world_size() > 1:
broadcast_object_list(l, src=0)
2023-08-04 01:26:36 +00:00
_command = l[0]
return _command
2023-08-02 23:12:36 +00:00
def _make_infinite_epochs(dl):
2023-08-04 01:26:36 +00:00
while True:
_logger.info("New epoch starts.")
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
2023-08-02 23:12:36 +00:00
@local_leader_only(default=None)
def logger(data):
2023-08-04 01:26:36 +00:00
return _logger.info(json.dumps(data, default=str))
2023-08-02 23:12:36 +00:00
def seed(seed):
2023-08-04 01:26:36 +00:00
# Set up random seeds, after fork()
random.seed(seed + global_rank())
np.random.seed(seed + global_rank())
torch.manual_seed(seed + global_rank())
2023-08-02 23:12:36 +00:00
def train(
2023-08-04 01:26:36 +00:00
train_dl: DataLoader,
train_feeder: TrainFeeder = default_feeder,
eval_fn: EvalFn = lambda x: ...,
logger: Logger = logger,
2023-08-02 23:12:36 +00:00
):
2023-08-04 01:26:36 +00:00
engines = load_engines()
"""
if is_local_leader():
cfg.dump()
_logger.info(cfg)
"""
# Setup global engines
global _engines
_engines = engines
events = []
eval_fn = global_leader_only(eval_fn)
# Pre-loop command
command = _non_blocking_input()
if command in ["eval", "eval_quit"]:
engines.eval()
eval_fn(engines=engines)
engines.train()
if command in ["quit", "eval_quit"]:
return
last_save_step = engines.global_step
last_eval_step = 0
# Training loop
for batch in _make_infinite_epochs(train_dl):
if engines.global_step >= cfg.trainer.iterations:
break
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
stats['it'] = stats['global_step']
stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
2023-08-04 01:26:36 +00:00
stats['batch'] = {
'size': len(batch['text']),
2023-08-04 01:26:36 +00:00
'id': batch['spkr_id'],
'index': [ index for index in batch['index'] ],
'text_len': [ text.shape[0] for text in batch['text'] ],
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
}
del stats['global_step']
elapsed_time = stats.get("elapsed_time", 0)
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
command = _non_blocking_input()
if "@" in command:
what, when = command.split("@")
try:
events.append((what, int(when)))
_logger.info(f"Event {command} registered.")
except Exception as e:
_logger.error(e)
command = ""
# Commands are the current command plus the triggered (i.e. iteration >= trigger point) events
events = [e for e in events if e[1] >= engines.global_step]
commands = [command] + [e[0] for e in events if e[1] == engines.global_step]
for command in commands:
if command in ["event show", "event"]:
msg = "Events:\n" + "\n".join(["@".join(map(str, e)) for e in events])
_logger.info(msg)
if command == "event clear":
events.clear()
if "time" in command:
target_iter = cfg.trainer.iterations
if " to " in command:
try:
target_iter = int(command.split(" to ")[-1])
except Exception as e:
_logger.error(e)
remaining_iters = target_iter - engines.global_step + 1
remaining_time = int(remaining_iters * elapsed_time)
_logger.info(humanize.precisedelta(remaining_time))
if "lr" in command:
rate = float(command.split(" ")[-1])
try:
engines.set_lr(rate)
print("Updating LR to:", rate)
except Exception as e:
print("Failed to set LR rate to:", rate, str(e))
if "export" in command:
2023-09-04 02:27:13 +00:00
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
engines.save_checkpoint()
last_save_step = engines.global_step
if is_global_leader():
engines.export(userdata={"symmap": get_phone_symmap()})
2023-08-04 01:26:36 +00:00
save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
saving_commands = ["save"]
export_commands = ["export"]
2023-08-04 01:26:36 +00:00
if cfg.trainer.save_on_quit:
saving_commands.append("quit")
if cfg.trainer.export_on_quit:
export_commands.append("quit")
if cfg.trainer.export_on_save:
export_commands.append("save")
2023-08-04 01:26:36 +00:00
if engines.global_step != last_save_step:
if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
2023-09-04 02:27:13 +00:00
train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt")
2023-08-04 01:26:36 +00:00
engines.save_checkpoint()
last_save_step = engines.global_step
if command in export_commands and is_global_leader():
engines.export(userdata={"symmap": get_phone_symmap()})
2023-08-04 01:26:36 +00:00
if engines.global_step != last_eval_step:
if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
do_gc()
engines.eval()
eval_fn(engines=engines)
engines.train()
last_eval_step = engines.global_step
if command in ["quit"]:
return