added ability to mark models as disabled for training, and hotloading them for eval/validation (useful if training only one model, or training a model per GPU)

This commit is contained in:
mrq 2023-08-27 12:26:12 -05:00
parent 165a1154e0
commit 87c4bfedba
8 changed files with 107 additions and 60 deletions

View File

@ -161,6 +161,7 @@ class Model:
prom_levels: int = 8 prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
arch_type: str = "transformer" arch_type: str = "transformer"
training: bool = True
@property @property
def scale(self): def scale(self):
@ -215,10 +216,11 @@ class Model:
@dataclass() @dataclass()
class Models: class Models:
_max_levels: int = 0 _max_levels: int = 0
_prom_levels: int = 1
_models: list[Model] = field(default_factory=lambda: [ _models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8), Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8), Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
]) ])
def get(self, name=None): def get(self, name=None):
@ -241,7 +243,7 @@ class Models:
@property @property
def prom_levels(self): def prom_levels(self):
prom_levels = 1 prom_levels = self._prom_levels
for model in self._models: for model in self._models:
prom_levels = max(prom_levels, model.prom_levels) prom_levels = max(prom_levels, model.prom_levels)
return prom_levels return prom_levels
@ -279,6 +281,8 @@ class Evaluation:
ar_temperature: float = 1.0 ar_temperature: float = 1.0
nar_temperature: float = 0.2 nar_temperature: float = 0.2
load_disabled_engines: bool = True
@dataclass() @dataclass()
class DeepSpeed: class DeepSpeed:
zero_optimization_level: int = 0 zero_optimization_level: int = 0
@ -407,6 +411,8 @@ class Trainer:
aggressive_optimizations: bool = False aggressive_optimizations: bool = False
check_for_oom: bool = True check_for_oom: bool = True
load_disabled_engines: bool = False
gc_mode: str | None = None gc_mode: str | None = None
weight_dtype: str = "float16" weight_dtype: str = "float16"

View File

@ -459,7 +459,7 @@ class Dataset(_Dataset):
return dict( return dict(
index=index, index=index,
path=path, path=Path(path),
spkr_name=spkr_name, spkr_name=spkr_name,
spkr_id=spkr_id, spkr_id=spkr_id,
task=task, task=task,

View File

@ -8,4 +8,4 @@ if cfg.trainer.backend == "deepspeed":
elif cfg.trainer.backend == "local": elif cfg.trainer.backend == "local":
from .base import Engine from .base import Engine
from .base import Engines, TrainFeeder, default_feeder from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine

View File

@ -74,6 +74,10 @@ class Engine():
p.requires_grad_(True) p.requires_grad_(True)
self._frozen_params.clear() self._frozen_params.clear()
@property
def _training(self):
return self._cfg.training
@property @property
def global_step(self): def global_step(self):
return self.global_steps return self.global_steps
@ -82,7 +86,8 @@ class Engine():
def micro_step(self): def micro_step(self):
return self.micro_steps return self.micro_steps
def train_batch_size(self): @property
def batch_size(self):
return cfg.hyperparameters.batch_size return cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
@ -137,7 +142,10 @@ class Engine():
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs) self.module = self.module.to(*args, **kwargs)
return self.module if self.optimizer:
self.optimizer = self.optimizer.to(*args, **kwargs)
return self
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
@ -199,9 +207,7 @@ class Engines(dict[str, Engine]):
def setup(self): def setup(self):
self._global_step = 0 self._global_step = 0
self._micro_step = 0 self._micro_step = 0
self._batch_size = 0
for name, engine in self.items():
engine.name = name
@property @property
def global_step(self): def global_step(self):
@ -211,6 +217,10 @@ class Engines(dict[str, Engine]):
def micro_step(self): def micro_step(self):
return self._micro_step return self._micro_step
@property
def batch_size(self):
return self._batch_size
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
ret = {} ret = {}
for engine in self.values(): for engine in self.values():
@ -242,6 +252,9 @@ class Engines(dict[str, Engine]):
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True) cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
for name, engine in self.items(): for name, engine in self.items():
if not engine._training:
continue
save_dir = cfg.ckpt_dir / name save_dir = cfg.ckpt_dir / name
try: try:
engine.save_checkpoint(save_dir, tag=tag) engine.save_checkpoint(save_dir, tag=tag)
@ -282,25 +295,19 @@ class Engines(dict[str, Engine]):
if cfg.hyperparameters.scheduler_type == "": if cfg.hyperparameters.scheduler_type == "":
self.set_lr(cfg.hyperparameters.learning_rate) self.set_lr(cfg.hyperparameters.learning_rate)
self._update_global_step() self._update()
self._update_micro_step()
def set_lr(self, lr): def set_lr(self, lr):
for engine in self.values(): for engine in self.values():
if not engine._training:
continue
engine.set_lr(lr) engine.set_lr(lr)
def _update_global_step(self): def _update(self):
for engine in self.values(): for engine in self.values():
self._global_step = max(self._global_step, engine.global_step) self._global_step = max(self._global_step, engine.global_step)
def _update_micro_step(self):
for engine in self.values():
self._micro_step = max(self._micro_step, engine.micro_step) self._micro_step = max(self._micro_step, engine.micro_step)
self._batch_size = max(self._batch_size, engine.batch_size)
def train_batch_size(self):
batch_size = 0
for engine in self.values():
batch_size = max(batch_size, engine.train_batch_size())
def eval(self): def eval(self):
for engine in self.values(): for engine in self.values():
@ -325,8 +332,10 @@ class Engines(dict[str, Engine]):
if cfg.trainer.gc_mode == 'step': if cfg.trainer.gc_mode == 'step':
do_gc() do_gc()
for name, engine in self.items(): for name, engine in self.items():
if not engine._training:
continue
device = engine.device device = engine.device
if cfg.trainer.gc_mode == 'substep': if cfg.trainer.gc_mode == 'substep':
@ -424,9 +433,9 @@ class Engines(dict[str, Engine]):
), ),
) )
self._update_global_step() self._update()
self._update_micro_step()
stats["batch_size"] = self.train_batch_size() # len(batch["text"]) stats["batch_size"] = self.batch_size
stats["elapsed_time"] = total_elapsed_time stats["elapsed_time"] = total_elapsed_time
stats["wall_time"] = time.time() stats["wall_time"] = time.time()
stats["global_step"] = self.global_step stats["global_step"] = self.global_step

View File

@ -51,6 +51,10 @@ class Engine(DeepSpeedEngine):
for p in self._frozen_params: for p in self._frozen_params:
p.requires_grad_(True) p.requires_grad_(True)
self._frozen_params.clear() self._frozen_params.clear()
@property
def _training(self):
return self._cfg.training
@property @property
def global_step(self): def global_step(self):
@ -60,6 +64,10 @@ class Engine(DeepSpeedEngine):
def micro_step(self): def micro_step(self):
return self.micro_steps return self.micro_steps
@property
def batch_size(self):
return cfg.hyperparameters.batch_size
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs) return gather_attribute(self.module, *args, **kwargs)

View File

@ -41,7 +41,7 @@ def train_feeder(engine, batch):
return loss, stats return loss, stats
@torch.inference_mode() @torch.inference_mode()
def run_eval(engines, eval_name, dl): def run_eval(engines, disabled_engines, eval_name, dl):
engines_stats = { engines_stats = {
'eval': eval_name 'eval': eval_name
} }
@ -51,11 +51,23 @@ def run_eval(engines, eval_name, dl):
names = [] names = []
for name, engine in engines.items(): for name, engine in engines.items():
names.append(name)
if name[:2] == "ar": if name[:2] == "ar":
AR = engine AR = engine
elif name[:3] == "nar": elif name[:3] == "nar":
NAR = engine NAR = engine
else:
continue
names.append(name)
# hotload the missing models
for name, engine in disabled_engines.items():
if AR is None and name[:2] == "ar":
AR = engine
elif NAR is None and name[:3] == "nar":
NAR = engine
else:
continue
names.append(name)
stats = defaultdict(list) stats = defaultdict(list)
stats['loss'] = [] stats['loss'] = []
@ -148,13 +160,18 @@ def main():
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
def eval_fn(engines): def eval_fn(engines):
disabled_engines = load_engines(invert=True) if cfg.evaluation.load_disabled_engines else {}
try: try:
run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, disabled_engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl) run_eval(engines, disabled_engines, "val", val_dl)
except Exception as e: except Exception as e:
print("Error occurred while performing eval:", str(e)) print("Error occurred while performing eval:", str(e))
print(traceback.format_exc()) print(traceback.format_exc())
if len(disabled_engines.keys()):
for name, engine in disabled_engines.items():
engine = engine.to("cpu")
del disabled_engines
qnt.unload_model() qnt.unload_model()
do_gc() do_gc()

View File

@ -28,7 +28,7 @@ from .distributed import (
local_leader_only, local_leader_only,
) )
from ..engines import Engine, Engines, TrainFeeder, default_feeder from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
from ..models import get_models from ..models import get_models
from .utils import to_device, do_gc from .utils import to_device, do_gc
@ -36,36 +36,28 @@ from ..utils import wrapper as ml
from ..data import get_phone_symmap # should decouple from this trainer script from ..data import get_phone_symmap # should decouple from this trainer script
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
_engines: Engines
_command: str _command: str
def get_global_step(): def load_engines(invert=False):
try:
return _engines.global_step
except:
return None
def get_micro_step():
try:
return _engines.micro_step
except:
return None
def get_cmd():
try:
return _command
except:
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
get_iteration = get_global_step
def load_engines():
models = get_models(cfg.models.get()) models = get_models(cfg.models.get())
engines = dict() engines = dict()
for name in models: for name, model in models.items():
model = models[name] # load only the models for training initially
# loads disabled models at evaluation time (to load updated weights if training separately)
# I'm sure there's a more elegant solution to this
if cfg.evaluation.load_disabled_engines:
if not invert and not model._cfg.training:
continue
if invert and model._cfg.training:
continue
# load only the models for training initially
# if load_disabled_engines, then models not marked for training will be loaded but ignored
# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
# I recommend not using this pathway
elif not cfg.trainer.load_disabled_engines:
if model._cfg.training:
continue
optimizer = None optimizer = None
lr_scheduler = None lr_scheduler = None
@ -82,7 +74,11 @@ def load_engines():
weight_decay=0.01, weight_decay=0.01,
) )
if cfg.trainer.load_state_dict: if not model._cfg.training:
optimizer = None
lr_scheduler = None
if cfg.trainer.load_state_dict or not model._cfg.training:
load_path = cfg.ckpt_dir / name / "fp32.pth" load_path = cfg.ckpt_dir / name / "fp32.pth"
state = torch.load(load_path) state = torch.load(load_path)
# exporting the model from the zero_to_fp32.py exports the actual module's dict # exporting the model from the zero_to_fp32.py exports the actual module's dict
@ -109,18 +105,18 @@ def load_engines():
# copy weights from the dict into the old portion # copy weights from the dict into the old portion
model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :] model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
# reuse additional levels, probably bad
for n in range(o_resp_tokens, n_resp_tokens):
model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
# copy the full tensors back # copy the full tensors back
state['resps_emb.weight'] = model.resps_emb.weight state['resps_emb.weight'] = model.resps_emb.weight
model.load_state_dict(state, strict=cfg.trainer.strict_loading) model.load_state_dict(state, strict=cfg.trainer.strict_loading)
engines[name] = Engine( # use base engine because DeepSpeed memory leaks
engines[name] = (Engine if model._cfg.training else _Engine)(
#engines[name] = Engine(
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
_cfg=model._cfg, _cfg=model._cfg,
) )
@ -130,6 +126,8 @@ def load_engines():
if not cfg.trainer.load_state_dict: if not cfg.trainer.load_state_dict:
engines.load_checkpoint() engines.load_checkpoint()
do_gc()
return engines return engines
class EvalFn(Protocol): class EvalFn(Protocol):

View File

@ -17,12 +17,21 @@ from torch import Tensor, nn
from tqdm.auto import tqdm from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload from typing import Callable, TypeVar, overload
try:
from deepspeed.runtime.utils import empty_cache
except Exception as e:
print(str(e))
def empty_cache():
...
T = TypeVar("T") T = TypeVar("T")
def do_gc(): def do_gc():
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
empty_cache()
def flatten_dict(d): def flatten_dict(d):
records = pd.json_normalize(d).to_dict(orient="records") records = pd.json_normalize(d).to_dict(orient="records")
return records[0] if records else {} return records[0] if records else {}