vall-e/vall_e/engines/base.py
2024-05-04 21:03:46 -05:00

502 lines
14 KiB
Python
Executable File

from torch import Tensor
from typing import Any, Protocol
Stats = dict[str, float]
class TrainFeeder(Protocol):
def __call__(
self, *, engine: "Engine", batch: Any
) -> None | tuple[Tensor, Stats]:
...
def default_feeder(engine, batch):
if isinstance(batch, list):
engine( *batch )
elif isinstance(batch, dict):
engine( **batch )
else:
engine( batch )
losses = engine.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
return loss, stats
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
import logging
import time
import torch
import torch.distributed
import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property
from .base import TrainFeeder
from ..utils import wrapper as ml
_logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local": # and world_size() > 1:
init_distributed(torch.distributed.init_process_group)
# A very naive engine implementation using barebones PyTorch
class Engine():
def __init__(self, *args, **kwargs):
if '_cfg' in kwargs:
self._cfg = kwargs['_cfg']
kwargs.pop("_cfg")
self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype)
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
self.global_steps = kwargs.pop("global_steps", 0)
self.micro_steps = kwargs.pop("micro_steps", 0)
self.global_samples = kwargs.pop("global_samples", 0)
self.tokens_processed = kwargs.pop("tokens_processed", 0)
self._frozen_params = set()
def freeze(self, freeze_all=True):
# set to freeze
if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params):
param.requires_grad_(False)
self._frozen_params.add(param)
def unfreeze(self):
for p in self._frozen_params:
p.requires_grad_(True)
self._frozen_params.clear()
@property
def _training(self):
if not hasattr(self, "_cfg"):
return True
return self._cfg.training
@property
def global_step(self):
return self.global_steps
@property
def micro_step(self):
return self.micro_steps
@property
def batch_size(self):
return cfg.hyperparameters.batch_size
@property
def gradient_accumulation_steps(self):
return cfg.hyperparameters.gradient_accumulation_steps
@property
def gradient_clipping(self):
return cfg.hyperparameters.gradient_clipping
def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs)
def dispatch_attribute(self, *args, **kwargs):
return dispatch_attribute(self.module, *args, **kwargs)
def save_checkpoint(self, save_dir, tag ):
save_path = save_dir / tag / "state.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
"module": self.module.state_dict(),
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
"stats": {
"global_step": self.global_step,
"micro_step": self.micro_step,
"global_samples": self.global_samples,
"tokens_processed": self.tokens_processed,
}
}, save_path)
open(save_dir / "latest", 'w').write( tag )
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
if tag is None:
tag_path = load_dir / "latest"
if not tag_path.exists():
return
tag = open(tag_path).read()
load_path = load_dir / tag / "state.pth"
if not load_path.exists():
return
state = torch.load(load_path, map_location=torch.device(cfg.device))
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
self.module.load_state_dict(state['module'])
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
if load_optimizer_states:
self.optimizer.load_state_dict(state['optimizer'], map_location=torch.device(cfg.device))
if load_lr_scheduler_states:
self.lr_scheduler.load_state_dict(state['lr_scheduler'], map_location=torch.device(cfg.device))
def eval(self):
return self.module.eval()
def train(self):
return self.module.train()
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
if self.optimizer:
self.optimizer = self.optimizer.to(*args, **kwargs)
return self
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
@cached_property
def device(self):
return next(self.module.parameters()).device
def forward(self, *args, **kwargs):
return self.module.forward(*args, **kwargs)
def backward(self, loss):
return (loss / self.gradient_accumulation_steps).backward()
def step(self):
with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
self.micro_steps += 1
self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
self.global_steps += 1
self.optimizer.step()
self.optimizer.zero_grad()
self._get_grad_norm()
def _get_grad_norm(self):
t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
self._global_grad_norm = torch.cat(t).norm().item() if len(t) else 0
def get_lr(self):
lrs = []
for param_group in self.optimizer.param_groups:
if 'd_coeff' in param_group:
lrs.append(param_group['d_coeff'])
elif 'lr' in param_group:
lrs.append(param_group['lr'])
return lrs
def set_lr(self, lr):
for param_group in self.optimizer.param_groups:
if 'd_coeff' in param_group:
param_group['d_coeff'] = lr
elif 'lr' in param_group:
param_group['lr'] = lr
def get_global_grad_norm(self):
return self._global_grad_norm
def traverse(self, *args, **kwargs):
with ml.autocast():
self.forward(*args, **kwargs)
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
self.backward(loss)
self.step()
return stats
# and now to ignore everything from the above
class Engines(dict[str, Engine]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setup()
def setup(self):
self._global_step = 0
self._micro_step = 0
self._batch_size = 0
self._global_samples = 0
@property
def global_step(self):
return self._global_step
@property
def micro_step(self):
return self._micro_step
@property
def batch_size(self):
return self._batch_size
@property
def global_samples(self):
return self._global_samples
def gather_attribute(self, *args, **kwargs):
ret = {}
for engine in self.values():
ret |= engine.gather_attribute(*args, **kwargs)
return ret
def dispatch_attribute(self, *args, **kwargs):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}):
for name, engine in self.items():
outpath = cfg.ckpt_dir / name / "fp32.pth"
state_dict = {
'module': engine.module.state_dict(),
"stats": {
"global_step": engine.global_step,
"micro_step": engine.micro_step,
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata
}
torch.save(state_dict, outpath)
print(f"Exported {name} to {outpath}")
def save_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.save_tag
tag = tag.lower()
if tag[:2] == "it" or tag[:4] == "step":
tag = f'{self.global_step}'
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items():
if not engine._training:
continue
save_dir = cfg.ckpt_dir / name
try:
engine.save_checkpoint(save_dir, tag=tag)
except Exception as e:
print(f'Failed to save checkpoint for engine {name}:', str(e))
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
checkpoints = [ d for d in list(save_dir.glob("*")) if d.is_dir() ]
checkpoints.sort(key=lambda x: x.stat().st_mtime)
checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
for d in checkpoints:
if not d.is_dir() or not d.exists():
continue
print("Removing", d)
for p in d.iterdir():
p.unlink()
d.rmdir()
def load_checkpoint(self, tag=None):
if not tag:
tag = cfg.trainer.load_tag
for name, engine in self.items():
load_dir = cfg.ckpt_dir / name
engine.load_checkpoint(
tag=tag,
load_dir=load_dir,
load_module_strict=cfg.trainer.strict_loading,
load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
load_module_only=cfg.trainer.load_module_only,
)
if cfg.trainer.restart_step_count:
engine.global_steps = 0
engine.mocro_step = 0
engine.global_samples = 0
engine.tokens_processed = 0
# update the LR because for some god awful reason it gets overwritten when loading from a checkpoint but only when it's not using a scheduler
if cfg.hyperparameters.scheduler_type == "":
self.set_lr(cfg.hyperparameters.learning_rate)
self._update()
def set_lr(self, lr):
for engine in self.values():
if not engine._training:
continue
engine.set_lr(lr)
def _update(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
self._micro_step = max(self._micro_step, engine.micro_step)
self._batch_size = max(self._batch_size, engine.batch_size)
self._global_samples = max(self._global_samples, engine.global_samples)
def eval(self):
for engine in self.values():
engine.eval()
def train(self):
for engine in self.values():
engine.train()
def traverse(self):
stats = {}
for name, engine in self.items():
stat = engine.traverse()
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0
stats: Any = dict()
if cfg.trainer.gc_mode == 'step':
do_gc()
for name, engine in self.items():
if not engine._training:
continue
device = engine.device
if cfg.trainer.gc_mode == 'substep':
do_gc()
start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=device)
batch = to_device(batch, device)
if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch )
else:
while tries >= 0:
try:
res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e:
print("Forward", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
# shrink batch size until it's happy
for k in batch:
batch[k] = batch[k][:-1]
if tries <= 0:
# trigger OOM
n_ooms += 1
else:
# also do GC
do_gc()
continue
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!")
if res is None:
continue
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
if not cfg.trainer.check_for_oom:
engine.backward(loss)
else:
# to-do: properly handle when one GPU throws an OOM because it just halts
try:
engine.backward(loss)
except RuntimeError as e:
print("Backwards:", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
n_ooms += 1
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
engine.step()
#torch.cuda.synchronize()
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
loss=loss.item(),
lr=engine.get_lr()[0],
grad_norm=engine.get_global_grad_norm(), # This norm is delayed but global and avoids extra computation
elapsed_time=elapsed_time,
engine_step=engine.global_step,
samples_processed=engine.global_samples,
tokens_processed=engine.tokens_processed,
**engine_stats,
)
}
),
)
self._update()
if len(self.keys()) > 1:
stats["elapsed_time"] = total_elapsed_time
stats["it"] = self.global_step
return stats