APOLLO cringe (doesn't want to work with deepspeed)
parent cddf8ca814
commit 9a62e3b824
@@ -621,7 +621,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 	if cfg.dataset.use_metadata and metadata_path.exists():
 		#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
-		metadata = json_read( metadata_path )
+		try:
+			metadata = json_read( metadata_path )
+		except Exception as e:
+			return []
 
 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
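In effect, a metadata file that exists but cannot be parsed now causes _load_paths_from_metadata to return an empty path list instead of raising during dataloader setup. A minimal standalone sketch of the same guard (json_read is assumed to be a thin json.load wrapper and load_paths is a hypothetical stand-in for the real function; neither body is part of this diff):

import json

def json_read(path):
	# assumed behavior of the helper: parse a JSON file, raising on malformed or truncated input
	with open(path, "r", encoding="utf-8") as f:
		return json.load(f)

def load_paths(metadata_path):
	try:
		metadata = json_read(metadata_path)
	except Exception:
		# mirrors the change above: a bad metadata file yields no paths rather than a crash
		return []
	return list(metadata.keys())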
@@ -132,6 +132,9 @@ def load_engines(training=True, **model_kwargs):
 			params['d_coef'] = params['lr']
 			params['lr'] = 1.0
 		elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
+			if backend == "deepspeed":
+				raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
+
 			optimizer_class = ml.Apollo
 			is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
 			param_kwargs = {
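As a practical note for anyone hitting the new exception: the guard keys off the resolved training backend before the optimizer is built, so switching the backend away from DeepSpeed (or picking a different optimizer) avoids it. A rough sketch of the decision, assuming backend is resolved from the trainer/inference config elsewhere in load_engines (that resolution is not shown in this hunk):

# hedged sketch, not part of the diff above
backend = cfg.trainer.backend if training else cfg.inference.backend

if cfg.hyperparameters.optimizer.lower() in ["apollo", "apollo-mini"] and backend == "deepspeed":
	# either train with the local backend or choose another optimizer (e.g. adamw)
	raise Exception("APOLLO currently does not play nicely with DeepSpeed.")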
@@ -146,8 +149,22 @@ def load_engines(training=True, **model_kwargs):
 				param_kwargs.update(cfg.hyperparameters.optimizer_params)
 				# and blank it so it doesn't update the main optimizer kwargs
 				cfg.hyperparameters.optimizer_params = {}
 			# settings are stored under params
-			params["params"] = [dict(params=params["params"], **param_kwargs)]
+			"""
+			params["params"] = [{'params': params["params"]} | param_kwargs]
+			"""
+			target_params = []
+			target_modules_list = ["attn", "mlp"]
+			for module_name, module in model.named_modules():
+				if not (isinstance(module, torch.nn.Linear)):
+					continue
+				if not any(target_key in module_name for target_key in target_modules_list):
+					continue
+				target_params.append(module.weight)
+
+			param_ids = [id(p) for p in target_params]
+			regular_params = [p for p in model.parameters() if id(p) not in param_ids]
+			params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 			optimizer_class = ml.Adagrad
 		else:
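The param-group split added above is the usual APOLLO recipe: only the 2D Linear weights inside attention/MLP blocks receive the low-rank projection settings, while everything else stays in a plain group. A self-contained sketch of the same idea with a toy model (the module names and apollo_kwargs values here are illustrative; ml.Apollo is assumed to accept per-group keyword settings the way this commit uses them):

import torch

# toy stand-in for the real transformer; only modules whose names contain "attn"/"mlp" get projected
model = torch.nn.ModuleDict({
	"attn": torch.nn.Linear(64, 64),
	"mlp": torch.nn.Linear(64, 256),
	"head": torch.nn.Linear(64, 10),
})

apollo_kwargs = dict(rank=1, proj="random", scale_type="tensor", scale=128, update_proj_gap=200, proj_type="std")

target_params = []
for name, module in model.named_modules():
	if not isinstance(module, torch.nn.Linear):
		continue
	if not any(key in name for key in ["attn", "mlp"]):
		continue
	target_params.append(module.weight)

param_ids = {id(p) for p in target_params}
regular_params = [p for p in model.parameters() if id(p) not in param_ids]

# two groups: the APOLLO settings only apply to the projected group
param_groups = [{"params": regular_params}, {"params": target_params, **apollo_kwargs}]
# optimizer = ml.Apollo(param_groups, lr=1e-2)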
@@ -207,6 +224,28 @@ def load_engines(training=True, **model_kwargs):
 		for k in erase:
 			del state[k]
 
+		# converts an AR+NAR model into an AR+NAR-len model
+		"""
+		if True:
+			# move STT one over
+			state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
+			state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
+			# copy from AR:0:0 classifier
+			if True:
+				state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
+				state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
+				# copy from AR:0:0 embeddings
+				state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
+			# remove
+			else:
+				if 'classifiers.proj.8.weight' in state:
+					del state['classifiers.proj.8.weight']
+				if 'classifiers.proj.8.bias' in state:
+					del state['classifiers.proj.8.bias']
+				if 'resps_emb.embeddings.8.weight' in state:
+					del state['resps_emb.embeddings.8.weight']
+		"""
+
 		# resize modules if I'm doing experiments and can't be assed to manually trim things
 		if cfg.trainer.resize_modules:
 			uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
@@ -235,69 +274,49 @@ def load_engines(training=True, **model_kwargs):
 					continue
 				state[k] = ml.resize_weight( state[k], tokens )
 
-		"""
-		if True:
-			# move STT one over
-			state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
-			state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
-			# copy from AR:0:0 classifier
-			if False:
-				state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
-				state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
-				# copy from AR:0:0 embeddings
-				state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
-			# remove
-			else:
-				if 'classifiers.proj.8.weight' in state:
-					del state['classifiers.proj.8.weight']
-				if 'classifiers.proj.8.bias' in state:
-					del state['classifiers.proj.8.bias']
-				if 'resps_emb.embeddings.8.weight' in state:
-					del state['resps_emb.embeddings.8.weight']
-		"""
-
-		"""
-		if True:
-			remapped_dict = {}
-			remapped_indices = [
-				(0, 1),
-				(1, 2),
-				(2, 3),
-				(3, 5),
-				(4, 6),
-				(5, 7),
-				(6, 9),
-				(7, 10),
-				(8, 11),
-				(9, 13),
-				(10, 14),
-				(11, 15),
-			]
-
-			for src, dst in remapped_indices:
-				remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
-				remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
-
-				del state[f"model.layers.{src}.input_layernorm.weight"]
-				del state[f"model.layers.{src}.self_attn.k_proj.weight"]
-				del state[f"model.layers.{src}.self_attn.q_proj.weight"]
-				del state[f"model.layers.{src}.self_attn.v_proj.weight"]
-				del state[f"model.layers.{src}.self_attn.o_proj.weight"]
-				del state[f"model.layers.{src}.post_attention_layernorm.weight"]
-				del state[f"model.layers.{src}.mlp.down_proj.weight"]
-				del state[f"model.layers.{src}.mlp.gate_proj.weight"]
-				del state[f"model.layers.{src}.mlp.up_proj.weight"]
-
-			for k, v in remapped_dict.items():
-				state[k] = v
-		"""
+		# stuff to inject new layers into an existing model train over (not recommended, it doesnt amount to anything)
+		"""
+		if True:
+			remapped_dict = {}
+			remapped_indices = [
+				(0, 1),
+				(1, 2),
+				(2, 3),
+				(3, 5),
+				(4, 6),
+				(5, 7),
+				(6, 9),
+				(7, 10),
+				(8, 11),
+				(9, 13),
+				(10, 14),
+				(11, 15),
+			]
+
+			for src, dst in remapped_indices:
+				remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
+				remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
+				remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
+
+				del state[f"model.layers.{src}.input_layernorm.weight"]
+				del state[f"model.layers.{src}.self_attn.k_proj.weight"]
+				del state[f"model.layers.{src}.self_attn.q_proj.weight"]
+				del state[f"model.layers.{src}.self_attn.v_proj.weight"]
+				del state[f"model.layers.{src}.self_attn.o_proj.weight"]
+				del state[f"model.layers.{src}.post_attention_layernorm.weight"]
+				del state[f"model.layers.{src}.mlp.down_proj.weight"]
+				del state[f"model.layers.{src}.mlp.gate_proj.weight"]
+				del state[f"model.layers.{src}.mlp.up_proj.weight"]
+
+			for k, v in remapped_dict.items():
+				state[k] = v
+		"""
 
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
@@ -934,6 +934,7 @@ def example_usage():
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
 	learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
 
+	params = model.parameters()
 	if cfg.optimizations.dadaptation:
 		# do not combine the two
 		if scheduler == "schedulefree":
@@ -966,12 +967,28 @@ def example_usage():
 		learning_rate = 0.01
 
 		optimizer = ml.Apollo
+
+		"""
+		target_params = []
+		target_modules_list = ["attn", "mlp"]
+		for module_name, module in model.named_modules():
+			if not (isinstance(module, torch.nn.Linear)):
+				continue
+			if not any(target_key in module_name for target_key in target_modules_list):
+				continue
+			target_params.append(module.weight)
+
+		param_ids = [id(p) for p in target_params]
+		regular_params = [p for p in model.parameters() if id(p) not in param_ids]
+		params = [{'params': regular_params}, {'params': target_params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
+		"""
+		params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128,'update_proj_gap': 200, 'proj_type': 'std'}]
 	else:
 		raise ValueError(f"Unrecognized optimizer: {optimizer}")
 
 	_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
 
-	optimizer = optimizer(model.parameters(), lr=learning_rate)
+	optimizer = optimizer(params, lr=learning_rate)
 
 	if scheduler == "schedulefree":
 		if isinstance(optimizer, ml.AdamW):
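For clarity, the one-line change at the bottom of this hunk is what makes the block above take effect: the optimizer is now constructed from the prepared params list of group dicts rather than from raw model.parameters(). A minimal illustration of the two call shapes (sketch only; ml.Apollo is assumed to follow the standard torch.optim.Optimizer constructor contract of accepting either an iterable of tensors or a list of group dicts):

# before: one implicit group, no APOLLO settings attached to it
optimizer = ml.Apollo(model.parameters(), lr=learning_rate)

# after: explicit group dict, so 'rank' / 'proj' / 'scale' etc. reach Apollo.step() via group lookups
params = [{'params': model.parameters(), 'rank': 1, 'proj': 'random',
	'scale_type': 'tensor', 'scale': 128, 'update_proj_gap': 200, 'proj_type': 'std'}]
optimizer = ml.Apollo(params, lr=learning_rate)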
@@ -326,8 +326,10 @@ class Apollo(Optimizer):
 		if closure is not None:
 			loss = closure()
 
+		params_idx = 0
 		for group in self.param_groups:
 			for p in group["params"]:
+				params_idx += 1
 				if p.grad is None:
 					continue
 				grad = p.grad
@@ -339,6 +341,9 @@ class Apollo(Optimizer):
 				if "step" not in state:
 					state["step"] = 0
 
+				if "seed" not in state:
+					state["seed"] = params_idx
+
 				# GaLore Projection
 				if "rank" in group:
 					if "projector" not in state:
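The params_idx bookkeeping above gives every parameter a distinct, stable value stored in its optimizer state as "seed", presumably so each tensor's random projection can be regenerated deterministically; since it lives in the optimizer state, it also persists through optimizer checkpointing. A condensed view of just that bookkeeping (sketch; the state = self.state[p] lookup and the actual projection update are elided from the hunks above):

# inside Apollo.step(): number parameters in traversal order
params_idx = 0
for group in self.param_groups:
	for p in group["params"]:
		params_idx += 1
		if p.grad is None:
			continue
		state = self.state[p]
		if "step" not in state:
			state["step"] = 0
		if "seed" not in state:
			# one fixed seed per parameter, so its random projection is reproducible
			state["seed"] = params_idx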