2023-08-04 01:26:36 +00:00
|
|
|
from ..config import cfg
|
|
|
|
|
2024-11-20 22:27:51 +00:00
|
|
|
from ..utils.distributed import fix_unset_envs, ddp_model, world_size
|
2023-08-05 03:22:15 +00:00
|
|
|
# Presumably fills in any distributed-launch environment variables that are
# missing (see `utils.distributed.fix_unset_envs`); runs once at import time.
fix_unset_envs()

# Select the `Engine` class used by `load_engines` for non-"local" backends.
# NOTE(review): for any other `cfg.trainer.backend` value `Engine` is left
# undefined by this module, so a non-local backend would fail later on use.
if cfg.trainer.backend == "local":
    from .base import Engine
elif cfg.trainer.backend == "deepspeed":
    from .deepspeed import Engine
|
|
|
|
|
2024-06-17 18:17:24 +00:00
|
|
|
from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
|
2023-10-09 20:24:04 +00:00
|
|
|
|
2024-07-16 23:23:13 +00:00
|
|
|
from ..models import get_models, get_model
|
2023-10-09 20:24:04 +00:00
|
|
|
from ..utils import wrapper as ml
|
2024-08-04 04:34:18 +00:00
|
|
|
from ..utils.io import torch_save, torch_load, pick_path
|
2024-07-23 01:47:24 +00:00
|
|
|
from ..models.lora import apply_lora, lora_load_state_dict
|
2024-06-17 18:55:37 +00:00
|
|
|
|
2023-10-09 20:24:04 +00:00
|
|
|
import torch
|
2023-12-26 03:20:32 +00:00
|
|
|
import re
|
2024-08-29 18:27:16 +00:00
|
|
|
import logging
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)

# DeepSpeed is optional here: `deepspeed_available` gates DeepSpeed-accelerated
# local inferencing inside `load_engines`.
deepspeed_available = False
try:
    import deepspeed
    deepspeed_available = True
except Exception as e:
    # Best-effort import: don't fail, but leave a trace of why it is unavailable.
    # Logged at debug level because running without DeepSpeed is a normal setup.
    _logger.debug(f'Failed to import deepspeed: {str(e)}')

# wandb is optional: when it cannot be imported, `wandb` is None and
# `load_engines` skips wandb setup for every engine.
try:
    import wandb
except Exception as e:
    _logger.warning(f'Failed to import wandb: {str(e)}')
    wandb = None
|
|
|
|
|
2023-10-21 14:55:38 +00:00
|
|
|
from functools import cache
|
|
|
|
|
|
|
|
def _pick_weights_path(path):
    """Resolve a weights path against every supported weights format.

    Thin wrapper around `pick_path` that passes one `.{format}` suffix per
    entry of `cfg.supported_weights_formats`. Presumably `pick_path` returns
    `path` or an existing variant of it with one of those suffixes — see
    `utils.io.pick_path` to confirm.
    """
    return pick_path(path, *[f'.{fmt}' for fmt in cfg.supported_weights_formats])


def _build_optimizer(model):
    """Create the torch-side optimizer for `model` from `cfg.hyperparameters`.

    Only parameters whose names are not listed in `model.config.frozen_params`
    are optimized. `cfg.hyperparameters.optimizer_params` is merged into the
    optimizer kwargs, or into the single parameter group for Apollo. The shared
    config itself is never modified. When the scheduler is "schedulefree", the
    ScheduleFree variant of the optimizer is returned instead.

    Raises:
        ValueError: if the optimizer is not supported, or has no ScheduleFree
            variant while "schedulefree" is requested.
    """
    hyperparameters = cfg.hyperparameters
    optimizer_name = hyperparameters.optimizer.lower()
    # Local copy: per-optimizer handling below must not leak into the shared
    # config. Other models, or later calls, still need the original values.
    extra_params = dict(hyperparameters.optimizer_params)

    trainable = [ param for param_name, param in model.named_parameters() if param_name not in model.config.frozen_params ]
    params = {
        "params": trainable,
        "lr": hyperparameters.learning_rate,
    }

    if optimizer_name == "adamw":
        params["betas"] = (0.9, 0.96)
        params["eps"] = 1e-07
        params["weight_decay"] = 0.01

        # for dadaptation since it has Adam only
        if ml.AdamW == ml.Adam:
            params["decouple"] = True

        optimizer_class = ml.AdamW
    elif optimizer_name == "sgd":
        optimizer_class = ml.SGD
    elif optimizer_name == "prodigy":
        optimizer_class = ml.Prodigy
        # the configured learning rate is passed as `d_coef`; `lr` is pinned to 1.0
        params['d_coef'] = params['lr']
        params['lr'] = 1.0
    elif optimizer_name in ["apollo", "apollo-mini"]:
        optimizer_class = ml.Apollo
        is_mini = optimizer_name == "apollo-mini"
        param_kwargs = {
            "rank": 1 if is_mini else 256,
            "proj": "random",
            "scale_type": "tensor" if is_mini else "channel",
            "scale": 128 if is_mini else 1,
            "update_proj_gap": 200,
            "proj_type": "std",
        }
        # extra YAML settings are per-group settings for Apollo...
        param_kwargs.update(extra_params)
        # ...so they must not also be passed as top-level optimizer kwargs
        extra_params = {}
        # settings are stored under params
        params["params"] = [dict(params=params["params"], **param_kwargs)]
    elif optimizer_name == "adagrad":
        optimizer_class = ml.Adagrad
    else:
        raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

    params.update(extra_params)
    optimizer = optimizer_class(**params)

    if hyperparameters.scheduler.lower() == "schedulefree":
        if optimizer_name == "adamw":
            scheduler_class = ml.schedulefree.AdamWScheduleFree
        elif optimizer_name == "sgd":
            scheduler_class = ml.schedulefree.SGDScheduleFree
        else:
            raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')

        # the ScheduleFree optimizer replaces the one built above
        optimizer = scheduler_class(
            trainable,
            lr = params['lr'],
            warmup_steps = cfg.hyperparameters.warmup_steps
        )

    return optimizer


def _upgrade_state_dict(model, state):
    """Adapt a loaded module state dict to the current model, in place.

    Steps:
      * Renames legacy `retnet.` key prefixes to `model.`.
      * When `cfg.trainer.resize_modules` is set, resizes token-indexed weights
        with `ml.resize_weight`. This covers embeddings and classifiers; each
        target size comes from `model.config`.
    Keys that are absent from `state` are skipped.
    """
    # maintain compat if variable names change
    renames = {}
    for key in state:
        new_key = re.sub(r'^retnet\.', "model.", key)
        if new_key != key:
            renames[key] = new_key
    for old_key, new_key in renames.items():
        state[new_key] = state.pop(old_key)

    if not cfg.trainer.resize_modules:
        return

    # one extra token slot when the model has an "ar" or "len" capability
    uses_stop_token = int("ar" in model.capabilities or "len" in model.capabilities)
    audio_tokens = model.config.audio_tokens + uses_stop_token
    classifier_prefix = "classifiers.proj.0" if model.config.experimental.split_classifiers else "classifier"

    keys = [
        ("text_emb.weight", model.config.text_tokens ),
        ("tasks_emb.weight", model.config.tasks ),
        ("langs_emb.weight", model.config.langs ),
        ("rvq_l_emb.weight", model.config.resp_levels ),
        ("resps_emb.embeddings.0.weight", audio_tokens ),
        ("model.embed_tokens.weight", audio_tokens ),
        (f"{classifier_prefix}.weight", audio_tokens ),
        (f"{classifier_prefix}.bias", audio_tokens ),
    ]

    # correcting an oversight: fixed sizes for the "len" and "NAR:0:0" split classifiers
    if model.config.experimental.split_classifiers and "len" in model.capabilities:
        len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
        keys.append((f"classifiers.proj.{len_idx}.weight", 11))
        keys.append((f"classifiers.proj.{len_idx}.bias", 11))

        keys.append((f"classifiers.proj.{nar_0_idx}.weight", 1024))
        keys.append((f"classifiers.proj.{nar_0_idx}.bias", 1024))

    for key, tokens in keys:
        if key not in state:
            continue
        state[key] = ml.resize_weight( state[key], tokens )


@cache
def load_engines(training=True, **model_kwargs):
    """Build and set up the `Engines` collection for every model in `cfg.models`.

    Per-model setup:
      1. Decide whether the model is used for inference. That is the case when
         `cfg.mode` is "inferencing", the model's config is not training,
         `training` is False, or the model is a teacher.
      2. Choose the backend from that decision.
      3. Resolve the checkpoint or weights path. The state dict is loaded when
         requested, or when only plain weights exist.
      4. Apply the optional linear/embedding replacements and LoRA adapters.
      5. When training, build the optimizer.
      6. Load the state dict with legacy-key upgrades, then any LoRA weights.
      7. Wrap the model for DDP, `torch.compile` or DeepSpeed inference, then
         in its engine class.

    Afterwards the checkpoints are loaded (unless `cfg.trainer.load_state_dict`
    is set). Each engine is then frozen as configured, optionally offloaded,
    set to train/eval mode, and attached to wandb.

    Args:
        training: False forces every model into inference mode.
        **model_kwargs: forwarded to `get_models`. They must be hashable,
            because results are memoized with `functools.cache`.

    Returns:
        The set-up `Engines` instance. Repeated calls with the same arguments
        return the same cached object.
    """
    models = get_models(cfg.models, training=training, **model_kwargs)
    engines = {}

    # fallback for the post-loop checkpoint load when no models are configured
    inferencing = cfg.mode == "inferencing" or not training

    for name, model in models.items():
        state = None
        stats = None

        inferencing = cfg.mode == "inferencing" or not model.config.training or not training or model.config.teacher
        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
        loads_state_dict = cfg.trainer.load_state_dict # or inferencing

        checkpoint_path = cfg.ckpt_dir / name / "latest"
        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
        load_path = _pick_weights_path( cfg.ckpt_dir / name / f"{cfg.weights_name}.{cfg.weights_format}" )

        # actually use the lora-specific checkpoint if available
        if cfg.lora is not None:
            checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"

        # handle training with deepspeed but inferencing with local:
        # "latest" holds the tag of the checkpoint directory to load from
        if checkpoint_path.exists() and backend == "local":
            with open(checkpoint_path) as f:
                tag = f.read().strip()
            checkpoint_path = _pick_weights_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}" )

        # if loaded using --model=
        if model.config.path and model.config.path.exists():
            load_path = model.config.path

        if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
            _logger.warning(f"Checkpoint missing, but weights found: {load_path}")
            loads_state_dict = True

        # load state early
        if loads_state_dict:
            state = torch_load(load_path, device=cfg.device)

        hyper_config = model.config

        optimizer = None
        # no LR scheduler is constructed here; engines always receive None
        lr_scheduler = None

        dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
        amp = cfg.inference.amp if inferencing else cfg.trainer.amp
        ddp = cfg.trainer.ddp

        engine_class = LocalEngine if backend == "local" else Engine

        # apply model replacers
        if cfg.optimizations.replace and cfg.optimizations.linear:
            model.model = ml.replace_linear( model.model )

        if cfg.optimizations.replace and cfg.optimizations.embedding:
            model.model = ml.replace_embedding( model.model )

        for lora_cfg in cfg.loras:
            model.model = apply_lora( model.model, rank = lora_cfg.rank, alpha = lora_cfg.alpha, policy = model.config.lora_policy, use_parametrize = lora_cfg.parametrize )

        if inferencing:
            model.config.training = False

        # an optimizer is only built when training. The DeepSpeed backend needs
        # one only when `torch_optimizer` is set.
        if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
            optimizer = _build_optimizer(model)

        # load state dict if requested / required
        if loads_state_dict:
            # state dict is not just the module, extract the extra trainer details
            if "stats" in state:
                stats = state["stats"]

            # do not load stats if we're training a LoRA
            if cfg.lora is not None or cfg.trainer.restart_step_count:
                stats = None

            if "module" in state:
                state = state["module"]

            _upgrade_state_dict(model, state)

            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        # load lora weights if exists
        if cfg.lora is not None:
            if cfg.lora.path:
                lora_path = cfg.lora.path
            else:
                lora_path = _pick_weights_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}" )

            if lora_path.exists():
                _logger.info( f"Loaded LoRA state dict: {lora_path}" )

                state = torch_load(lora_path, device=cfg.device)
                state = state['lora' if 'lora' in state else 'module']
                lora_load_state_dict( model, state )

        # wrap if DDP is requested
        if ddp:
            model = ddp_model(model)
        # wrap optimization class
        elif cfg.optimizations.compile:
            model = ml.compile_model(model, backend=cfg.optimizations.compile)
        # deepspeed inferencing
        elif backend == "local" and inferencing and deepspeed_available and cfg.trainer.deepspeed.inferencing: #and sys.platform.startswith("win"):
            engine_class = LocalEngine
            model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

        # use base engine if requested
        engines[name] = engine_class(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,

            hyper_config=hyper_config,
            stats=stats
        )

    engines = Engines(engines)
    engines.setup()

    # NOTE: this uses `inferencing` from the last model processed, so a mix of
    # training and inferencing models is not distinguished here.
    if not cfg.trainer.load_state_dict:
        engines.load_checkpoint(training=not inferencing)

    # freeze requested params
    for name, engine in engines.items():
        engine.freeze(freeze_all=False)

        # split models over requested devices
        if cfg.optimizations.model_offloading:
            engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )

        # set to train/eval
        if engine.hyper_config.training:
            engine.module.train()
        else:
            engine.module.eval()

        # setup wandb
        if engine._training and cfg.trainer.wandb and wandb is not None:
            key_name = name
            kwargs = {}
            if cfg.lora is not None:
                key_name = cfg.lora.full_name

            if world_size() > 1:
                kwargs["group"] = "DDP"

            engine.wandb = wandb.init(project=key_name, **kwargs)
            engine.wandb.watch(engine.module)
        else:
            engine.wandb = None

    return engines
|