added the ability to mark models as disabled for training, and to hotload them for eval/validation (useful if training only one model, or training a model per GPU)
parent 165a1154e0
commit 87c4bfedba
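In rough terms, the feature hangs off a new per-model `training` flag. A minimal sketch of the idea (field defaults abbreviated from the `Model` dataclass changed below; this is not the project's actual config file):

from dataclasses import dataclass

@dataclass
class Model:
	name: str
	resp_levels: int = 1
	prom_levels: int = 8
	tasks: int = 8
	training: bool = True  # new in this commit

# e.g. train only the AR on this GPU; the NAR is marked disabled and is only
# hotloaded for eval/validation with whatever weights its own trainer last saved
ar = Model(name="ar", resp_levels=1, training=True)
nar = Model(name="nar", resp_levels=7, training=False)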
@@ -161,6 +161,7 @@ class Model:
 	prom_levels: int = 8
 	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
 	arch_type: str = "transformer"
+	training: bool = True
 
 	@property
 	def scale(self):
@@ -215,10 +216,11 @@ class Model:
 @dataclass()
 class Models:
 	_max_levels: int = 0
+	_prom_levels: int = 1
 
 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8),
-		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True),
 	])
 
 	def get(self, name=None):
@@ -241,7 +243,7 @@ class Models:
 
 	@property
 	def prom_levels(self):
-		prom_levels = 1
+		prom_levels = self._prom_levels
 		for model in self._models:
 			prom_levels = max(prom_levels, model.prom_levels)
 		return prom_levels
@@ -279,6 +281,8 @@ class Evaluation:
 	ar_temperature: float = 1.0
 	nar_temperature: float = 0.2
 
+	load_disabled_engines: bool = True
+
 @dataclass()
 class DeepSpeed:
 	zero_optimization_level: int = 0
@@ -407,6 +411,8 @@ class Trainer:
 	aggressive_optimizations: bool = False
 	check_for_oom: bool = True
 
+	load_disabled_engines: bool = False
+
 	gc_mode: str | None = None
 
 	weight_dtype: str = "float16"
@@ -459,7 +459,7 @@ class Dataset(_Dataset):
 
 		return dict(
 			index=index,
-			path=path,
+			path=Path(path),
 			spkr_name=spkr_name,
 			spkr_id=spkr_id,
 			task=task,
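Two separate knobs come out of the config hunks above: `Evaluation.load_disabled_engines` (default True) hotloads non-training models just for eval, while `Trainer.load_disabled_engines` (default False) controls whether they are kept around during training at all. A toy sketch of the per-model decision, condensed from the `load_engines()` hunk later in this diff (the function name and parameters here are stand-ins, not project code):

def should_load(model_training: bool, invert: bool,
                eval_hotload: bool = True,     # Evaluation.load_disabled_engines
                trainer_keep: bool = False):   # Trainer.load_disabled_engines
	if eval_hotload:
		# normal pass loads training models; the inverted eval pass loads the rest
		return model_training != invert
	elif not trainer_keep:
		# mirrors the committed elif branch as written
		return not model_training
	return True

assert should_load(True, invert=False)       # training model: in the working set
assert not should_load(False, invert=False)  # disabled model: skipped at startup
assert should_load(False, invert=True)       # ...then hotloaded for eval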
@@ -8,4 +8,4 @@ if cfg.trainer.backend == "deepspeed":
 elif cfg.trainer.backend == "local":
 	from .base import Engine
 
-from .base import Engines, TrainFeeder, default_feeder
+from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
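The extra `Engine as _Engine` re-export gives callers a stable name for the plain (non-DeepSpeed) engine no matter which backend supplied `Engine`; `load_engines()` below uses it to wrap disabled models without DeepSpeed, per the memory-leak comment there. A stand-in demo of the selection (toy classes, not the real engines):

class Engine: ...      # backend-selected engine (DeepSpeed or local)
class _Engine: ...     # always the plain local engine

training = False       # stand-in for model._cfg.training on a disabled model
engine_cls = Engine if training else _Engine
assert engine_cls is _Engine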
@@ -74,6 +74,10 @@ class Engine():
 			p.requires_grad_(True)
 		self._frozen_params.clear()
 
+	@property
+	def _training(self):
+		return self._cfg.training
+
 	@property
 	def global_step(self):
 		return self.global_steps
@@ -82,7 +86,8 @@ class Engine():
 	def micro_step(self):
 		return self.micro_steps
 
-	def train_batch_size(self):
+	@property
+	def batch_size(self):
 		return cfg.hyperparameters.batch_size
 
 	def gather_attribute(self, *args, **kwargs):
@@ -137,7 +142,10 @@ class Engine():
 
 	def to(self, *args, **kwargs):
 		self.module = self.module.to(*args, **kwargs)
-		return self.module
+		if self.optimizer:
+			self.optimizer = self.optimizer.to(*args, **kwargs)
+
+		return self
 
 	def __call__(self, *args, **kwargs):
 		return self.forward(*args, **kwargs)
@@ -199,9 +207,7 @@ class Engines(dict[str, Engine]):
 	def setup(self):
 		self._global_step = 0
 		self._micro_step = 0
-
-		for name, engine in self.items():
-			engine.name = name
+		self._batch_size = 0
 
 	@property
 	def global_step(self):
@@ -211,6 +217,10 @@ class Engines(dict[str, Engine]):
 	def micro_step(self):
 		return self._micro_step
 
+	@property
+	def batch_size(self):
+		return self._batch_size
+
 	def gather_attribute(self, *args, **kwargs):
 		ret = {}
 		for engine in self.values():
@@ -242,6 +252,9 @@ class Engines(dict[str, Engine]):
 
 		cfg.ckpt_dir.mkdir(parents=True, exist_ok=True)
 		for name, engine in self.items():
+			if not engine._training:
+				continue
+
 			save_dir = cfg.ckpt_dir / name
 			try:
 				engine.save_checkpoint(save_dir, tag=tag)
@@ -282,25 +295,19 @@ class Engines(dict[str, Engine]):
 		if cfg.hyperparameters.scheduler_type == "":
 			self.set_lr(cfg.hyperparameters.learning_rate)
 
-		self._update_global_step()
-		self._update_micro_step()
+		self._update()
 
 	def set_lr(self, lr):
 		for engine in self.values():
+			if not engine._training:
+				continue
 			engine.set_lr(lr)
 
-	def _update_global_step(self):
+	def _update(self):
 		for engine in self.values():
 			self._global_step = max(self._global_step, engine.global_step)
-
-	def _update_micro_step(self):
-		for engine in self.values():
 			self._micro_step = max(self._micro_step, engine.micro_step)
-	def train_batch_size(self):
-		batch_size = 0
-		for engine in self.values():
-			batch_size = max(batch_size, engine.train_batch_size())
-		return batch_size
+			self._batch_size = max(self._batch_size, engine.batch_size)
 
 	def eval(self):
 		for engine in self.values():
@@ -325,8 +332,10 @@ class Engines(dict[str, Engine]):
 		if cfg.trainer.gc_mode == 'step':
 			do_gc()
 
-
 		for name, engine in self.items():
+			if not engine._training:
+				continue
+
 			device = engine.device
 
 		if cfg.trainer.gc_mode == 'substep':
@@ -424,9 +433,9 @@ class Engines(dict[str, Engine]):
 			),
 		)
 
-		self._update_global_step()
-		self._update_micro_step()
-		stats["batch_size"] = self.train_batch_size() # len(batch["text"])
+		self._update()
+
+		stats["batch_size"] = self.batch_size
 		stats["elapsed_time"] = total_elapsed_time
 		stats["wall_time"] = time.time()
 		stats["global_step"] = self.global_step
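The bookkeeping change in this file: the separate `_update_global_step()`/`_update_micro_step()` loops and the ad-hoc `train_batch_size()` method collapse into a single `_update()` pass, with `batch_size` exposed as a property. A self-contained toy of the consolidated shape (stand-in classes, not the real engines):

class ToyEngine:
	def __init__(self, global_step, micro_step, batch_size):
		self.global_step = global_step
		self.micro_step = micro_step
		self.batch_size = batch_size

class ToyEngines(dict):
	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
		self._global_step = self._micro_step = self._batch_size = 0

	def _update(self):
		# one pass over the engines keeps every aggregate in sync
		for engine in self.values():
			self._global_step = max(self._global_step, engine.global_step)
			self._micro_step = max(self._micro_step, engine.micro_step)
			self._batch_size = max(self._batch_size, engine.batch_size)

	@property
	def batch_size(self):
		return self._batch_size

engines = ToyEngines(ar=ToyEngine(10, 40, 8), nar=ToyEngine(12, 48, 8))
engines._update()
assert engines.batch_size == 8 and engines._global_step == 12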
@@ -51,6 +51,10 @@ class Engine(DeepSpeedEngine):
 		for p in self._frozen_params:
 			p.requires_grad_(True)
 		self._frozen_params.clear()
 
+	@property
+	def _training(self):
+		return self._cfg.training
+
 	@property
 	def global_step(self):
@@ -60,6 +64,10 @@ class Engine(DeepSpeedEngine):
 	def micro_step(self):
 		return self.micro_steps
 
+	@property
+	def batch_size(self):
+		return cfg.hyperparameters.batch_size
+
 	def gather_attribute(self, *args, **kwargs):
 		return gather_attribute(self.module, *args, **kwargs)
 
@@ -41,7 +41,7 @@ def train_feeder(engine, batch):
 	return loss, stats
 
 @torch.inference_mode()
-def run_eval(engines, eval_name, dl):
+def run_eval(engines, disabled_engines, eval_name, dl):
 	engines_stats = {
 		'eval': eval_name
 	}
@@ -51,11 +51,23 @@ def run_eval(engines, eval_name, dl):
 
 	names = []
 	for name, engine in engines.items():
-		names.append(name)
 		if name[:2] == "ar":
 			AR = engine
 		elif name[:3] == "nar":
 			NAR = engine
+		else:
+			continue
+		names.append(name)
+
+	# hotload the missing models
+	for name, engine in disabled_engines.items():
+		if AR is None and name[:2] == "ar":
+			AR = engine
+		elif NAR is None and name[:3] == "nar":
+			NAR = engine
+		else:
+			continue
+		names.append(name)
 
 	stats = defaultdict(list)
 	stats['loss'] = []
@@ -148,13 +160,18 @@ def main():
 	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
 
 	def eval_fn(engines):
+		disabled_engines = load_engines(invert=True) if cfg.evaluation.load_disabled_engines else {}
 		try:
-			run_eval(engines, "subtrain", subtrain_dl)
-			run_eval(engines, "val", val_dl)
+			run_eval(engines, disabled_engines, "subtrain", subtrain_dl)
+			run_eval(engines, disabled_engines, "val", val_dl)
 		except Exception as e:
 			print("Error occurred while performing eval:", str(e))
 			print(traceback.format_exc())
 
+		if len(disabled_engines.keys()):
+			for name, engine in disabled_engines.items():
+				engine = engine.to("cpu")
+			del disabled_engines
 		qnt.unload_model()
 		do_gc()
 
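The hotload matching in `run_eval` boils down to: fill the AR/NAR slots from the live engines first, then from the hotloaded set for whichever slot is still empty. A toy illustration (stand-in string values; the real code keeps the two loops separate and only None-guards the second):

AR = NAR = None
names = []
engines = {"ar": "trained-ar"}              # stand-ins, not real engines
disabled_engines = {"nar": "hotloaded-nar"}

for pool in (engines, disabled_engines):
	for name, engine in pool.items():
		if AR is None and name[:2] == "ar":
			AR = engine
		elif NAR is None and name[:3] == "nar":
			NAR = engine
		else:
			continue
		names.append(name)

assert AR == "trained-ar" and NAR == "hotloaded-nar" and names == ["ar", "nar"]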
@@ -28,7 +28,7 @@ from .distributed import (
 	local_leader_only,
 )
 
-from ..engines import Engine, Engines, TrainFeeder, default_feeder
+from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
 from ..models import get_models
 
 from .utils import to_device, do_gc
@@ -36,36 +36,28 @@ from ..utils import wrapper as ml
 from ..data import get_phone_symmap # should decouple from this trainer script
 
 _logger = logging.getLogger(__name__)
-_engines: Engines
 _command: str
 
-def get_global_step():
-	try:
-		return _engines.global_step
-	except:
-		return None
-
-def get_micro_step():
-	try:
-		return _engines.micro_step
-	except:
-		return None
-
-def get_cmd():
-	try:
-		return _command
-	except:
-		raise RuntimeError("Trainer has not been setup. Have you called trainer.train?")
-
-
-get_iteration = get_global_step
-
-def load_engines():
+def load_engines(invert=False):
 	models = get_models(cfg.models.get())
 	engines = dict()
 
-	for name in models:
-		model = models[name]
+	for name, model in models.items():
+		# load only the models for training initially
+		# loads disabled models at evaluation time (to load updated weights if training separately)
+		# I'm sure there's a more elegant solution to this
+		if cfg.evaluation.load_disabled_engines:
+			if not invert and not model._cfg.training:
+				continue
+			if invert and model._cfg.training:
+				continue
+		# load only the models for training initially
+		# if load_disabled_engines, then models not marked for training will be loaded but ignored
+		# DeepSpeed has some weird quirks where loading an engine and moving it to CPU will have a memory leak or something
+		# I recommend not using this pathway
+		elif not cfg.trainer.load_disabled_engines:
+			if model._cfg.training:
				continue
 
 		optimizer = None
 		lr_scheduler = None
@@ -82,7 +74,11 @@ def load_engines():
 			weight_decay=0.01,
 		)
 
-		if cfg.trainer.load_state_dict:
+		if not model._cfg.training:
+			optimizer = None
+			lr_scheduler = None
+
+		if cfg.trainer.load_state_dict or not model._cfg.training:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
 			state = torch.load(load_path)
 			# exporting the model from the zero_to_fp32.py exports the actual module's dict
@@ -109,18 +105,18 @@ def load_engines():
 
 			# copy weights from the dict into the old portion
 			model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
-			# reuse additional levels, probably bad
-			for n in range(o_resp_tokens, n_resp_tokens):
-				model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
 			# copy the full tensors back
 			state['resps_emb.weight'] = model.resps_emb.weight
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-		engines[name] = Engine(
+		# use base engine because DeepSpeed memory leaks
+		engines[name] = (Engine if model._cfg.training else _Engine)(
+		#engines[name] = Engine(
 			model=model,
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
 
 			_cfg=model._cfg,
 		)
@@ -130,6 +126,8 @@ def load_engines():
 	if not cfg.trainer.load_state_dict:
 		engines.load_checkpoint()
 
+	do_gc()
+
 	return engines
 
 class EvalFn(Protocol):
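Net effect of `load_engines(invert=...)` under the default `cfg.evaluation.load_disabled_engines = True`: the two passes partition the configured models by their `training` flag. A minimal sketch:

models = {"ar": True, "nar": False}  # name -> training flag

def select(invert):
	return [name for name, training in models.items() if training != invert]

assert select(invert=False) == ["ar"]   # loaded for training
assert select(invert=True) == ["nar"]   # hotloaded at eval time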
@@ -17,12 +17,21 @@ from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
 
+try:
+	from deepspeed.runtime.utils import empty_cache
+except Exception as e:
+	print(str(e))
+	def empty_cache():
+		...
+
 T = TypeVar("T")
 
 def do_gc():
 	gc.collect()
 	torch.cuda.empty_cache()
 
+	empty_cache()
+
 def flatten_dict(d):
 	records = pd.json_normalize(d).to_dict(orient="records")
 	return records[0] if records else {}