vall-e/vall_e/engines/base.py

414 lines
11 KiB
Python
Raw Normal View History

2023-08-04 01:26:36 +00:00
from torch import Tensor
from typing import Any, Protocol
Stats = dict[str, float]
class TrainFeeder(Protocol):
    """Structural type for callables that feed one batch to an engine.

    A feeder is invoked with the keyword-only arguments ``engine`` and
    ``batch``. It returns either ``None`` or a ``(loss, stats)`` pair, where
    ``loss`` is a tensor and ``stats`` maps names to floats.
    """

    def __call__(
        self,
        *,
        engine: "Engine",
        batch: Any,
    ) -> None | tuple[Tensor, Stats]:
        ...
2023-08-04 02:39:00 +00:00
def default_feeder(engine, batch):
    """Run ``engine`` on ``batch`` and collect its losses.

    The batch is forwarded according to its container type: a ``dict`` is
    expanded into keyword arguments, a ``list`` into positional arguments,
    and anything else is passed through as a single positional argument.

    Returns:
        A ``(loss, stats)`` tuple where ``loss`` is the sum of every tensor
        returned by ``engine.gather_attribute("loss")`` and ``stats`` maps
        each loss name to its Python scalar value.
    """
    if isinstance(batch, dict):
        engine(**batch)
    elif isinstance(batch, list):
        engine(*batch)
    else:
        engine(batch)

    per_loss = engine.gather_attribute("loss")
    total = torch.stack(list(per_loss.values())).sum()
    return total, {name: value.item() for name, value in per_loss.items()}
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader
2023-08-04 01:26:36 +00:00
import logging
import time
import torch
import torch.distributed
import os
from torch import Tensor
from torch.distributed import all_reduce
from typing import Any, Protocol
from functools import cached_property
2023-08-04 01:26:36 +00:00
from .base import TrainFeeder
_logger = logging.getLogger(__name__)
2023-08-05 03:22:15 +00:00
# Import-time side effect: when no distributed state has been set up yet and
# the configured trainer backend is "local", initialise it through
# torch.distributed's own process-group initialiser.
if not distributed_initialized():
    if cfg.trainer.backend == "local":
        init_distributed(torch.distributed.init_process_group)
2023-08-04 01:26:36 +00:00
# A very naive engine implementation using barebones PyTorch
class Engine():
    """Bare-bones PyTorch training engine.

    Wraps a module together with an optional optimizer and LR scheduler and
    keeps track of micro steps (one per ``step()`` call) and global steps
    (one per actual optimizer update).
    """

    def __init__(self, *args, **kwargs):
        """Build the engine from keyword arguments.

        Recognised keywords:
            model: required; moved to ``cfg.device`` and cast to
                ``cfg.trainer.dtype``.
            optimizer: optional optimizer (defaults to ``None``).
            lr_scheduler: optional LR scheduler (defaults to ``None``).
            _cfg: optional per-engine config, stored as ``self._cfg`` only
                when provided.
        """
        if '_cfg' in kwargs:
            self._cfg = kwargs.pop('_cfg')

        self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)

        self.optimizer = kwargs.get('optimizer')
        self.lr_scheduler = kwargs.get('lr_scheduler')

        self.global_steps = 0
        self.micro_steps = 0
        self.gradient_accumulation_steps = cfg.hyperparameters.gradient_accumulation_steps

        # Parameters frozen by freeze(), remembered so unfreeze() restores
        # exactly those and nothing else.
        self._frozen_params = set()

    def freeze(self):
        """Disable gradients on every currently-trainable parameter."""
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze(self):
        """Re-enable gradients on the parameters frozen by ``freeze()``."""
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def global_step(self):
        """Number of optimizer updates performed so far."""
        return self.global_steps

    @property
    def micro_step(self):
        """Number of ``step()`` calls performed so far."""
        return self.micro_steps

    def train_batch_size(self):
        """Return the configured batch size."""
        return cfg.hyperparameters.batch_size

    def gather_attribute(self, *args, **kwargs):
        """Forward to the module-level ``gather_attribute`` on ``self.module``."""
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        """Forward to the module-level ``dispatch_attribute`` on ``self.module``."""
        return dispatch_attribute(self.module, *args, **kwargs)

    def save_checkpoint(self, save_dir, tag):
        """Write ``<save_dir>/<tag>/state.pth`` and record ``tag`` as latest.

        Args:
            save_dir: path-like object supporting ``/`` (e.g. ``pathlib.Path``).
            tag: name of the checkpoint sub-directory.
        """
        save_path = save_dir / tag / "state.pth"
        save_path.parent.mkdir(parents=True, exist_ok=True)

        torch.save({
            "global_step": self.global_step,
            "micro_step": self.micro_step,
            "module": self.module.state_dict(),
            "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
            "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
        }, save_path)

        # Use a context manager so the handle is flushed and closed
        # deterministically instead of relying on garbage collection.
        with open(save_dir / "latest", 'w') as f:
            f.write(tag)

    def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True):
        """Restore state from ``<load_dir>/<tag>/state.pth``.

        Silently does nothing when the ``latest`` marker (used if ``tag`` is
        ``None``) or the checkpoint file does not exist.

        Args:
            load_dir: path-like object supporting ``/``.
            tag: checkpoint name; read from ``<load_dir>/latest`` when ``None``.
            load_module_strict: passed as ``strict`` to the module's
                ``load_state_dict``.
            load_optimizer_states: restore the optimizer state if both an
                optimizer and a saved state are available.
            load_lr_scheduler_states: same, for the LR scheduler.
        """
        if tag is None:
            tag_path = load_dir / "latest"
            if not tag_path.exists():
                return
            with open(tag_path) as f:
                # Tolerate a trailing newline in a hand-edited marker file.
                tag = f.read().strip()

        load_path = load_dir / tag / "state.pth"
        if not load_path.exists():
            return

        state = torch.load(load_path)

        self.global_steps = state['global_step']
        self.micro_steps = state['micro_step']
        self.module.load_state_dict(state['module'], strict=load_module_strict)

        # save_checkpoint() stores None when the optimizer/scheduler was
        # absent at save time, so key presence alone is not enough.
        load_optimizer_states = (
            load_optimizer_states
            and self.optimizer is not None
            and state.get('optimizer') is not None
        )
        load_lr_scheduler_states = (
            load_lr_scheduler_states
            and self.lr_scheduler is not None
            and state.get('lr_scheduler') is not None
        )

        if load_optimizer_states:
            self.optimizer.load_state_dict(state['optimizer'])

        if load_lr_scheduler_states:
            self.lr_scheduler.load_state_dict(state['lr_scheduler'])

    def eval(self):
        """Put the wrapped module in evaluation mode."""
        return self.module.eval()

    def train(self):
        """Put the wrapped module in training mode."""
        return self.module.train()

    def to(self, *args, **kwargs):
        """Move/cast the wrapped module and return it."""
        self.module = self.module.to(*args, **kwargs)
        # `device` is a cached_property; drop the cached value so it is
        # recomputed from the moved module on next access.
        self.__dict__.pop("device", None)
        return self.module

    def __call__(self, *args, **kwargs):
        """Alias for ``forward``."""
        return self.forward(*args, **kwargs)

    @cached_property
    def device(self):
        """Device of the module's first parameter (cached until ``to()``)."""
        return next(self.module.parameters()).device

    def forward(self, *args, **kwargs):
        """Call the wrapped module's ``forward`` directly."""
        return self.module.forward(*args, **kwargs)

    def backward(self, loss):
        """Backpropagate ``loss`` scaled by the accumulation step count."""
        # Clamp like step() does so a non-positive setting cannot divide by zero.
        return (loss / max(1, self.gradient_accumulation_steps)).backward()

    def step(self):
        """Count one micro step; update the optimizer once every
        ``gradient_accumulation_steps`` micro steps.

        NOTE(review): the LR scheduler is not stepped here - confirm whether
        that is handled elsewhere.
        """
        with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
            self.micro_steps += 1

            # micro_steps is already incremented, so test it directly; the
            # previous `+ 1` fired the first update one micro-batch early.
            if self.micro_steps % max(1, self.gradient_accumulation_steps) == 0:
                self.global_steps += 1
                self.optimizer.step()
                self.optimizer.zero_grad()

    def get_lr(self):
        """Return the ``lr`` of every optimizer param group that defines one."""
        return [
            param_group['lr']
            for param_group in self.optimizer.param_groups
            if 'lr' in param_group
        ]

    def set_lr(self, lr):
        """Overwrite ``lr`` on every optimizer param group that defines one."""
        for param_group in self.optimizer.param_groups:
            if 'lr' in param_group:
                param_group['lr'] = lr

    def get_global_grad_norm(self):
        """Placeholder: this engine does not track the gradient norm."""
        return 0.0

    def traverse(self, *args, **kwargs):
        """Run one full forward/backward/step cycle and return its stats.

        The stats contain each gathered ``loss`` as a Python scalar plus
        whatever ``gather_attribute("scalar")`` returns.
        """
        self.forward(*args, **kwargs)

        losses = self.gather_attribute("loss")
        loss = torch.stack([*losses.values()]).sum()

        stats = {k: v.item() for k, v in losses.items()}
        stats |= self.gather_attribute("scalar")

        self.backward(loss)
        self.step()

        return stats
# and now to ignore everything from the above
class Engines(dict[str, Engine]):
    """A name -> ``Engine`` mapping that drives all engines together.

    Tracks aggregate global and micro step counters, defined as the maximum
    observed over the contained engines.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping like a ``dict`` and reset the step counters."""
        super().__init__(*args, **kwargs)
        self.setup()

    def setup(self):
        """Reset the aggregate step counters to zero."""
        self._global_step = 0
        self._micro_step = 0

    @property
    def global_step(self):
        """Highest global step seen across engines at the last update."""
        return self._global_step

    @property
    def micro_step(self):
        """Highest micro step seen across engines at the last update."""
        return self._micro_step

    def gather_attribute(self, *args, **kwargs):
        """Merge ``gather_attribute`` results of every engine into one dict.

        Later engines overwrite earlier ones on key collisions.
        """
        ret = {}
        for engine in self.values():
            ret |= engine.gather_attribute(*args, **kwargs)
        return ret

    def dispatch_attribute(self, *args, **kwargs):
        """Call ``dispatch_attribute`` on every engine."""
        for engine in self.values():
            engine.dispatch_attribute(*args, **kwargs)

    def save_checkpoint(self, tag=None):
        """Save every engine under ``cfg.ckpt_dir / <name>`` and prune old ones.

        Args:
            tag: checkpoint name; falls back to ``cfg.trainer.save_tag``.
                A (lower-cased) tag starting with ``"it"`` or ``"step"`` is
                replaced by the current global step.
        """
        if not tag:
            tag = cfg.trainer.save_tag

        tag = tag.lower()
        if tag.startswith(("it", "step")):
            tag = f'{self.global_step}'

        cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
        for name, engine in self.items():
            save_dir = cfg.ckpt_dir / name
            engine.save_checkpoint(save_dir, tag=tag)

            # Pruning runs after saving and only on the global leader. With
            # keep_last_checkpoints == 0 nothing is pruned, because a [:-0]
            # slice would select nothing to keep.
            if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
                checkpoints = [d for d in save_dir.glob("*") if d.is_dir()]
                # Oldest first, so the slice below leaves the newest N alone.
                checkpoints.sort(key=lambda x: x.stat().st_mtime)
                checkpoints = checkpoints[:-cfg.trainer.keep_last_checkpoints]
                for d in checkpoints:
                    if not d.is_dir() or not d.exists():
                        continue
                    print("Removing", d)
                    for p in d.iterdir():
                        p.unlink()
                    d.rmdir()

    def load_checkpoint(self, tag=None):
        """Load every engine from ``cfg.ckpt_dir / <name>``.

        Args:
            tag: checkpoint name; falls back to ``cfg.trainer.load_tag``.
        """
        if not tag:
            tag = cfg.trainer.load_tag

        for name, engine in self.items():
            load_dir = cfg.ckpt_dir / name
            engine.load_checkpoint(
                tag=tag,
                load_dir=load_dir,
                load_module_strict=cfg.trainer.strict_loading,
                load_optimizer_states=cfg.trainer.load_states,
                load_lr_scheduler_states=cfg.trainer.load_states,
            )
            if cfg.trainer.restart_step_count:
                engine.global_steps = 0

        # Loading optimizer state overwrites the learning rate; when no
        # scheduler is configured, re-apply the configured value.
        if cfg.hyperparameters.scheduler_type == "":
            self.set_lr(cfg.hyperparameters.learning_rate)

        self._update_global_step()
        self._update_micro_step()

    def set_lr(self, lr):
        """Set the learning rate on every engine."""
        for engine in self.values():
            engine.set_lr(lr)

    def _update_global_step(self):
        """Raise the aggregate global step to the max over all engines."""
        for engine in self.values():
            self._global_step = max(self._global_step, engine.global_step)

    def _update_micro_step(self):
        """Raise the aggregate micro step to the max over all engines."""
        for engine in self.values():
            self._micro_step = max(self._micro_step, engine.micro_step)

    def train_batch_size(self):
        """Return the largest batch size reported by any engine (0 if empty)."""
        batch_size = 0
        for engine in self.values():
            batch_size = max(batch_size, engine.train_batch_size())
        # The result was previously computed but never returned.
        return batch_size

    def eval(self):
        """Put every engine in evaluation mode."""
        for engine in self.values():
            engine.eval()

    def train(self):
        """Put every engine in training mode."""
        for engine in self.values():
            engine.train()

    def traverse(self):
        """Run ``traverse()`` on every engine and return the flattened stats.

        Stats are namespaced by the part of the engine name before the
        first ``"-"``.
        """
        stats = {}
        for name, engine in self.items():
            stat = engine.traverse()
            stats.update(flatten_dict({name.split("-")[0]: stat}))
        return stats

    def step(self, batch, feeder: TrainFeeder = default_feeder):
        """Run one training step of every engine on ``batch``.

        For each engine: move the batch to the engine's device, run
        ``feeder`` (forward), then ``backward`` and ``step``. When
        ``cfg.trainer.check_for_oom`` is set, a forward pass that runs out of
        memory is retried a bounded number of times with a shrinking batch,
        and out-of-memory counts are all-reduced before deciding to abort.

        Args:
            batch: the training batch. The OOM retry path assumes it is a
                dict whose values can be sliced.
            feeder: callable invoked as ``feeder(engine=..., batch=...)``;
                returning ``None`` skips the engine for this step.

        Returns:
            Flattened per-engine stats plus ``batch_size``, ``elapsed_time``,
            ``wall_time`` and ``global_step``.

        Raises:
            RuntimeError: re-raised for non-OOM errors (after saving a
                checkpoint), or raised when an OOM is reported during the
                forward or backward pass.
        """
        total_elapsed_time = 0

        stats: Any = dict()

        if cfg.trainer.gc_mode == 'step':
            do_gc()

        for name, engine in self.items():
            device = engine.device

            if cfg.trainer.gc_mode == 'substep':
                do_gc()

            start_time = time.time()

            tries = 4
            n_ooms = torch.zeros([], device=device)

            batch = to_device(batch, device)

            if not cfg.trainer.check_for_oom:
                res = feeder(engine=engine, batch=batch)
            else:
                # Stays None if every attempt runs out of memory; the
                # n_ooms check below aborts before it is used in that case.
                res = None
                while tries >= 0:
                    try:
                        res = feeder(engine=engine, batch=batch)
                        break
                    except RuntimeError as e:
                        print("Forward", str(e))

                        if "out of memory" not in str(e):
                            self.save_checkpoint()
                            raise e

                        if tries <= 0:
                            # Retries exhausted: report the OOM and stop.
                            n_ooms += 1
                            break

                        # Consume one retry. Without this decrement the
                        # loop never terminated on a persistent OOM.
                        tries -= 1

                        # shrink batch size until it's happy
                        for k in batch:
                            batch[k] = batch[k][:-1]

                        # also do GC
                        do_gc()

                all_reduce(n_ooms)
                if n_ooms.item() > 0:
                    self.save_checkpoint()
                    raise RuntimeError("Out of memory during forward pass!")

            if res is None:
                continue

            loss, engine_stats = res
            # NOTE(review): this gathers scalars from *all* engines, not only
            # the current one - confirm that is intended.
            engine_stats |= self.gather_attribute("scalar")

            n_ooms = torch.zeros([], device=device)

            if cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, 'cpu')

            if not cfg.trainer.check_for_oom:
                engine.backward(loss)
            else:
                try:
                    engine.backward(loss)
                except RuntimeError as e:
                    print("Backwards:", str(e))

                    if "out of memory" not in str(e):
                        self.save_checkpoint()
                        raise e

                    n_ooms += 1

                all_reduce(n_ooms)
                if n_ooms.item() > 0:
                    self.save_checkpoint()
                    raise RuntimeError("Out of memory during backwards pass!")

            engine.step()

            #torch.cuda.synchronize()

            elapsed_time = time.time() - start_time
            total_elapsed_time += elapsed_time

            stats.update(
                flatten_dict(
                    {
                        name.split("-")[0]: dict(
                            loss=loss.item(),
                            lr=engine.get_lr()[0],
                            grad_norm=engine.get_global_grad_norm(),  # This norm is delayed but global and avoids extra computation
                            elapsed_time=elapsed_time,
                            engine_step=engine.global_step,
                            **engine_stats,
                        )
                    }
                ),
            )

        self._update_global_step()
        self._update_micro_step()
        stats["batch_size"] = self.train_batch_size()  # len(batch["text"])
        stats["elapsed_time"] = total_elapsed_time
        stats["wall_time"] = time.time()
        stats["global_step"] = self.global_step

        return stats