added ability to mark models as disabled for training, and hotloading them for eval/validation (useful if training only one model, or training a model per GPU)
This commit is contained in:
parent
165a1154e0
commit
87c4bfedba
|
@ -161,6 +161,7 @@ class Model:
|
|||
prom_levels: int = 8
|
||||
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
|
||||
arch_type: str = "transformer"
|
||||
training: bool = True
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
|
@ -215,10 +216,11 @@ class Model:
|
|||
@dataclass()
|
||||
class Models:
|
||||
_max_levels: int = 0
|
||||
_prom_levels: int = 1
|
||||
|
||||
_models: list[Model] = field(default_factory=lambda: [
|
||||
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8),
|
||||
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8),
|
||||
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
|
||||
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
|
||||
])
|
||||
|
||||
def get(self, name=None):
|
||||
|
@ -241,7 +243,7 @@ class Models:
|
|||
|
||||
@property
|
||||
def prom_levels(self):
|
||||
prom_levels = 1
|
||||
prom_levels = self._prom_levels
|
||||
for model in self._models:
|
||||
prom_levels = max(prom_levels, model.prom_levels)
|
||||
return prom_levels
|
||||
|
@ -279,6 +281,8 @@ class Evaluation:
|
|||
ar_temperature: float = 1.0
|
||||
nar_temperature: float = 0.2
|
||||
|
||||
load_disabled_engines: bool = True
|
||||
|
||||
@dataclass()
|
||||
class DeepSpeed:
|
||||
zero_optimization_level: int = 0
|
||||
|
@ -407,6 +411,8 @@ class Trainer:
|
|||
aggressive_optimizations: bool = False
|
||||
check_for_oom: bool = True
|
||||
|
||||
load_disabled_engines: bool = False
|
||||
|
||||
gc_mode: str | None = None
|
||||
|
||||
weight_dtype: str = "float16"
|
||||
|
|
|
@ -459,7 +459,7 @@ class Dataset(_Dataset):
|
|||
|
||||
return dict(
|
||||
index=index,
|
||||
path=path,
|
||||
path=Path(path),
|
||||
spkr_name=spkr_name,
|
||||
spkr_id=spkr_id,
|
||||
task=task,
|
||||
|
|
|
@ -8,4 +8,4 @@ if cfg.trainer.backend == "deepspeed":
|
|||
elif cfg.trainer.backend == "local":
|
||||
from .base import Engine
|
||||
|
||||
from .base import Engines, TrainFeeder, default_feeder
|
||||
from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
|
|
@ -74,6 +74,10 @@ class Engine():
|
|||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def _training(self):
|
||||
return self._cfg.training
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
@ -82,7 +86,8 @@ class Engine():
|
|||
def micro_step(self):
|
||||
return self.micro_steps
|
||||
|
||||
def train_batch_size(self):
|
||||
@property
|
||||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
|
@ -137,7 +142,10 @@ class Engine():
|
|||
|
||||
def to(self, *args, **kwargs):
|
||||
self.module = self.module.to(*args, **kwargs)
|
||||
return self.module
|
||||
if self.optimizer:
|
||||
self.optimizer = self.optimizer.to(*args, **kwargs)
|
||||
|
||||
return self
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
@ -199,9 +207,7 @@ class Engines(dict[str, Engine]):
|
|||
def setup(self):
|
||||
self._global_step = 0
|
||||
self._micro_step = 0
|
||||
|
||||
for name, engine in self.items():
|
||||
engine.name = name
|
||||
self._batch_size = 0
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
|
@ -211,6 +217,10 @@ class Engines(dict[str, Engine]):
|
|||
def micro_step(self):
|
||||
return self._micro_step
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
ret = {}
|
||||
for engine in self.values():
|
||||
|
@ -242,6 +252,9 @@ class Engines(dict[str, Engine]):
|
|||
|
||||
cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name, engine in self.items():
|
||||
if not engine._training:
|
||||
continue
|
||||
|
||||
save_dir = cfg.ckpt_dir / name
|
||||
try:
|
||||
engine.save_checkpoint(save_dir, tag=tag)
|
||||
|
@ -282,25 +295,19 @@ class Engines(dict[str, Engine]):
|
|||
if cfg.hyperparameters.scheduler_type == "":
|
||||
self.set_lr(cfg.hyperparameters.learning_rate)
|
||||
|
||||
self._update_global_step()
|
||||
self._update_micro_step()
|
||||
self._update()
|
||||
|
||||
def set_lr(self, lr):
|
||||
for engine in self.values():
|
||||
if not engine._training:
|
||||
continue
|
||||
engine.set_lr(lr)
|
||||
|
||||
def _update_global_step(self):
|
||||
def _update(self):
|
||||
for engine in self.values():
|
||||
self._global_step = max(self._global_step, engine.global_step)
|
||||
|
||||
def _update_micro_step(self):
|
||||
for engine in self.values():
|
||||
self._micro_step = max(self._micro_step, engine.micro_step)
|
||||
|
||||
def train_batch_size(self):
|
||||
batch_size = 0
|
||||
for engine in self.values():
|
||||
batch_size = max(batch_size, engine.train_batch_size())
|
||||
self._batch_size = max(self._batch_size, engine.batch_size)
|
||||
|
||||
def eval(self):
|
||||
for engine in self.values():
|
||||
|
@ -325,8 +332,10 @@ class Engines(dict[str, Engine]):
|
|||
if cfg.trainer.gc_mode == 'step':
|
||||
do_gc()
|
||||
|
||||
|
||||
for name, engine in self.items():
|
||||
if not engine._training:
|
||||
continue
|
||||
|
||||
device = engine.device
|
||||
|
||||
if cfg.trainer.gc_mode == 'substep':
|
||||
|
@ -424,9 +433,9 @@ class Engines(dict[str, Engine]):
|
|||
),
|
||||
)
|
||||
|
||||
self._update_global_step()
|
||||
self._update_micro_step()
|
||||
stats["batch_size"] = self.train_batch_size() # len(batch["text"])
|
||||
self._update()
|
||||
|
||||
stats["batch_size"] = self.batch_size
|
||||
stats["elapsed_time"] = total_elapsed_time
|
||||
stats["wall_time"] = time.time()
|
||||
stats["global_step"] = self.global_step
|
||||
|
|
|
@ -52,6 +52,10 @@ class Engine(DeepSpeedEngine):
|
|||
p.requires_grad_(True)
|
||||
self._frozen_params.clear()
|
||||
|
||||
@property
|
||||
def _training(self):
|
||||
return self._cfg.training
|
||||
|
||||
@property
|
||||
def global_step(self):
|
||||
return self.global_steps
|
||||
|
@ -60,6 +64,10 @@ class Engine(DeepSpeedEngine):
|
|||
def micro_step(self):
|
||||
return self.micro_steps
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return cfg.hyperparameters.batch_size
|
||||
|
||||
def gather_attribute(self, *args, **kwargs):
|
||||
return gather_attribute(self.module, *args, **kwargs)
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ def train_feeder(engine, batch):
|
|||
return loss, stats
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_eval(engines, eval_name, dl):
|
||||
def run_eval(engines, disabled_engines, eval_name, dl):
|
||||
engines_stats = {
|
||||
'eval': eval_name
|
||||
}
|
||||
|
@ -51,11 +51,23 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
names = []
|
||||
for name, engine in engines.items():
|
||||
names.append(name)
|
||||
if name[:2] == "ar":
|
||||
AR = engine
|
||||
elif name[:3] == "nar":
|
||||
NAR = engine
|
||||
else:
|
||||
continue
|
||||
names.append(name)
|
||||
|
||||
# hotload the missing models
|
||||
for name, engine in disabled_engines.items():
|
||||
if AR is None and name[:2] == "ar":
|
||||
AR = engine
|
||||
elif NAR is None and name[:3] == "nar":
|
||||
NAR = engine
|
||||
else:
|
||||
continue
|
||||
names.append(name)
|
||||
|
||||
stats = defaultdict(list)
|
||||
stats['loss'] = []
|
||||
|
@ -148,13 +160,18 @@ def main():
|
|||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
||||
def eval_fn(engines):
|
||||
disabled_engines = load_engines(invert=True) if cfg.evaluation.load_disabled_engines else {}
|
||||
try:
|
||||
run_eval(engines, "subtrain", subtrain_dl)
|
||||
run_eval(engines, "val", val_dl)
|
||||
run_eval(engines, disabled_engines, "subtrain", subtrain_dl)
|
||||
run_eval(engines, disabled_engines, "val", val_dl)
|
||||
except Exception as e:
|
||||
print("Error occurred while performing eval:", str(e))
|
||||
print(traceback.format_exc())
|
||||
|
||||
if len(disabled_engines.keys()):
|
||||
for name, engine in disabled_engines.items():
|
||||
engine = engine.to("cpu")
|
||||
del disabled_engines
|
||||
qnt.unload_model()
|
||||
do_gc()
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from .distributed import (
|
|||
local_leader_only,
|
||||
)
|
||||
|
||||
from ..engines import Engine, Engines, TrainFeeder, default_feeder
|
||||
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
|
||||
from ..models import get_models
|
||||
|
||||
from .utils import to_device, do_gc
|
||||
|
@ -36,36 +36,28 @@ from ..utils import wrapper as ml
|
|||
from ..data import get_phone_symmap # should decouple from this trainer script
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_engines: Engines
|
||||
_command: str
|
||||
|
||||
def get_global_step():
|
||||
try:
|
||||
return _engines.global_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_micro_step():
|
||||
try:
|
||||
return _engines.micro_step
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_cmd():
|
||||
try:
|
||||
return _command
|
||||
except:
|
||||
raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
|
||||
|
||||
|
||||
get_iteration = get_global_step
|
||||
|
||||
def load_engines():
|
||||
def load_engines(invert=False):
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
|
||||
for name in models:
|
||||
model = models[name]
|
||||
for name, model in models.items():
|
||||
# load only the models for training initially
|
||||
# loads disabled models at evaluation time (to load updated weights if training separately)
|
||||
# I'm sure there's a more elegant solution to this
|
||||
if cfg.evaluation.load_disabled_engines:
|
||||
if not invert and not model._cfg.training:
|
||||
continue
|
||||
if invert and model._cfg.training:
|
||||
continue
|
||||
# load only the models for training initially
|
||||
# if load_disabled_engines, then models not marked for training will be loaded but ignored
|
||||
# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
|
||||
# I recommend not using this pathway
|
||||
elif not cfg.trainer.load_disabled_engines:
|
||||
if model._cfg.training:
|
||||
continue
|
||||
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
@ -82,7 +74,11 @@ def load_engines():
|
|||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
if cfg.trainer.load_state_dict:
|
||||
if not model._cfg.training:
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
if cfg.trainer.load_state_dict or not model._cfg.training:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state = torch.load(load_path)
|
||||
# exporting the model from the zero_to_fp32.py exports the actual module's dict
|
||||
|
@ -109,18 +105,18 @@ def load_engines():
|
|||
|
||||
# copy weights from the dict into the old portion
|
||||
model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
|
||||
# reuse additional levels, probably bad
|
||||
for n in range(o_resp_tokens, n_resp_tokens):
|
||||
model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
|
||||
# copy the full tensors back
|
||||
state['resps_emb.weight'] = model.resps_emb.weight
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
engines[name] = Engine(
|
||||
# use base engine because DeepSpeed memory leaks
|
||||
engines[name] = (Engine if model._cfg.training else _Engine)(
|
||||
#engines[name] = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
|
||||
_cfg=model._cfg,
|
||||
)
|
||||
|
||||
|
@ -130,6 +126,8 @@ def load_engines():
|
|||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
|
||||
do_gc()
|
||||
|
||||
return engines
|
||||
|
||||
class EvalFn(Protocol):
|
||||
|
|
|
@ -17,12 +17,21 @@ from torch import Tensor, nn
|
|||
from tqdm.auto import tqdm
|
||||
from typing import Callable, TypeVar, overload
|
||||
|
||||
try:
|
||||
from deepspeed.runtime.utils import empty_cache
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
def empty_cache():
|
||||
...
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def do_gc():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
empty_cache()
|
||||
|
||||
def flatten_dict(d):
|
||||
records = pd.json_normalize(d).to_dict(orient="records")
|
||||
return records[0] if records else {}
|
||||
|
|
Loading…
Reference in New Issue
Block a user