APOLLO cringe (doesn't want to work with deepspeed)
This commit is contained in:
parent cddf8ca814
commit 9a62e3b824

@@ -621,7 +621,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
    if cfg.dataset.use_metadata and metadata_path.exists():
        #metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
        try:
            metadata = json_read( metadata_path )
        except Exception as e:
            return []

        if len(metadata) == 0:
            return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
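
The shape of this change is a defensive-read pattern: a parse failure returns an empty list, and an empty metadata file falls back to scanning the directory instead. A minimal standalone sketch of that pattern, using hypothetical names (read_group_paths, a plain directory scan) rather than the repo's own json_read/_fn:

# Hedged sketch of the fallback pattern above; names and layout are hypothetical.
import json
from pathlib import Path

def read_group_paths(metadata_path: Path, data_dir: Path) -> list:
    try:
        metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    except Exception:
        # unreadable or corrupt metadata: treat the group as empty rather than crashing
        return []

    if len(metadata) == 0:
        # empty metadata: fall back to scanning the data directory directly
        return sorted(str(p) for p in data_dir.iterdir())

    return list(metadata)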

@@ -132,6 +132,9 @@ def load_engines(training=True, **model_kwargs):
            params['d_coef'] = params['lr']
            params['lr'] = 1.0
        elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
            if backend == "deepspeed":
                raise Exception("APOLLO currently does not play nicely with DeepSpeed.")

            optimizer_class = ml.Apollo
            is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
            param_kwargs = {

@@ -146,8 +149,22 @@ def load_engines(training=True, **model_kwargs):
            param_kwargs.update(cfg.hyperparameters.optimizer_params)
            # and blank it so it doesn't update the main optimizer kwargs
            cfg.hyperparameters.optimizer_params = {}
            # settings are stored under params
            params["params"] = [dict(params=params["params"], **param_kwargs)]

            """
            params["params"] = [{'params': params["params"]} | param_kwargs]
            """
            target_params = []
            target_modules_list = ["attn", "mlp"]
            for module_name, module in model.named_modules():
                if not (isinstance(module, torch.nn.Linear)):
                    continue
                if not any(target_key in module_name for target_key in target_modules_list):
                    continue
                target_params.append(module.weight)

            param_ids = [id(p) for p in target_params]
            regular_params = [p for p in model.parameters() if id(p) not in param_ids]
            params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
        elif cfg.hyperparameters.optimizer.lower() == "adagrad":
            optimizer_class = ml.Adagrad
        else:
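
The per-module split above follows the usual APOLLO/GaLore recipe: only the Linear weights inside attn/mlp modules receive the projection settings, while every other parameter stays in a plain group. A runnable sketch of that partitioning on a toy model is below; torch.optim.AdamW is only a stand-in, since this commit hands the analogous group list to ml.Apollo with APOLLO-specific keys (rank, proj, scale, ...) instead.

# Sketch: split parameters into "regular" and "target" (attn/mlp Linear) groups.
# AdamW is a stand-in optimizer here; Apollo would receive an analogous list.
import torch

model = torch.nn.ModuleDict({
    "attn": torch.nn.Linear(16, 16),
    "mlp": torch.nn.Linear(16, 16),
    "norm": torch.nn.LayerNorm(16),
})

target_modules_list = ["attn", "mlp"]
target_params = [
    module.weight
    for module_name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
    and any(key in module_name for key in target_modules_list)
]

param_ids = {id(p) for p in target_params}
regular_params = [p for p in model.parameters() if id(p) not in param_ids]

param_groups = [
    {"params": regular_params},
    {"params": target_params, "weight_decay": 0.0},  # APOLLO keys (rank, proj, ...) would go here
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)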

@@ -207,6 +224,28 @@ def load_engines(training=True, **model_kwargs):
        for k in erase:
            del state[k]

        # converts an AR+NAR model into an AR+NAR-len model
        """
        if True:
            # move STT one over
            state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
            state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
            # copy from AR:0:0 classifier
            if True:
                state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
                state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
                # copy from AR:0:0 embeddings
                state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
            # remove
            else:
                if 'classifiers.proj.8.weight' in state:
                    del state['classifiers.proj.8.weight']
                if 'classifiers.proj.8.bias' in state:
                    del state['classifiers.proj.8.bias']
                if 'resps_emb.embeddings.8.weight' in state:
                    del state['resps_emb.embeddings.8.weight']
        """

        # resize modules if I'm doing experiments and can't be assed to manually trim things
        if cfg.trainer.resize_modules:
            uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
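
The commented-out conversion above is plain state-dict surgery: entries are shifted to new indices and slot 8 is reseeded from the level-0 head, with .clone() used so the new keys do not alias the old tensors. A toy illustration with made-up shapes:

# Toy sketch of index remapping in a state dict; shapes and key names are illustrative.
import torch

state = {f"classifiers.proj.{i}.weight": torch.randn(4, 4) for i in range(10)}

state["classifiers.proj.9.weight"] = state["classifiers.proj.8.weight"].clone()  # move slot 8 -> 9
state["classifiers.proj.8.weight"] = state["classifiers.proj.0.weight"].clone()  # reseed slot 8 from slot 0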

@@ -235,27 +274,7 @@ def load_engines(training=True, **model_kwargs):
                continue
            state[k] = ml.resize_weight( state[k], tokens )

        """
        if True:
            # move STT one over
            state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
            state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
            # copy from AR:0:0 classifier
            if False:
                state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
                state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
                # copy from AR:0:0 embeddings
                state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
            # remove
            else:
                if 'classifiers.proj.8.weight' in state:
                    del state['classifiers.proj.8.weight']
                if 'classifiers.proj.8.bias' in state:
                    del state['classifiers.proj.8.bias']
                if 'resps_emb.embeddings.8.weight' in state:
                    del state['resps_emb.embeddings.8.weight']
        """

        # stuff to inject new layers into an existing model train over (not recommended, it doesnt amount to anything)
        """
        if True:
            remapped_dict = {}

@@ -934,6 +934,7 @@ def example_usage():
    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
    learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None

    params = model.parameters()
    if cfg.optimizations.dadaptation:
        # do not combine the two
        if scheduler == "schedulefree":

@@ -966,12 +967,28 @@ def example_usage():
        learning_rate = 0.01

        optimizer = ml.Apollo

        """
        target_params = []
        target_modules_list = ["attn", "mlp"]
        for module_name, module in model.named_modules():
            if not (isinstance(module, torch.nn.Linear)):
                continue
            if not any(target_key in module_name for target_key in target_modules_list):
                continue
            target_params.append(module.weight)

        param_ids = [id(p) for p in target_params]
        regular_params = [p for p in model.parameters() if id(p) not in param_ids]
        params = [{'params': regular_params}, {'params': target_params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128, 'update_proj_gap': 200, 'proj_type': 'std'}]
        """
        params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128, 'update_proj_gap': 200, 'proj_type': 'std'}]
    else:
        raise ValueError(f"Unrecognized optimizer: {optimizer}")

    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

    optimizer = optimizer(model.parameters(), lr=learning_rate)
    optimizer = optimizer(params, lr=learning_rate)

    if scheduler == "schedulefree":
        if isinstance(optimizer, ml.AdamW):
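
Rewrapping params as [{'params': params, 'rank': 1, ...}] works because torch.optim.Optimizer preserves extra per-group keys, which is exactly what the '"rank" in group' check in Apollo.step() (hunks below) relies on. A minimal sketch of that mechanism with a toy optimizer, not Apollo itself:

# Sketch: extra keys placed in a param group survive into self.param_groups,
# where step() can branch on them (same idea as the '"rank" in group' check below).
import torch

class GroupEcho(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            if "rank" in group:
                print("low-rank group:", group["rank"], group.get("proj"))
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])  # placeholder plain-SGD update

model = torch.nn.Linear(8, 8)
params = [{"params": model.parameters(), "rank": 1, "proj": "random", "scale": 128}]
opt = GroupEcho(params, lr=0.01)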

@@ -326,8 +326,10 @@ class Apollo(Optimizer):
        if closure is not None:
            loss = closure()

        params_idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                params_idx += 1
                if p.grad is None:
                    continue
                grad = p.grad

@@ -339,6 +341,9 @@ class Apollo(Optimizer):
                if "step" not in state:
                    state["step"] = 0

                if "seed" not in state:
                    state["seed"] = params_idx

                # GaLore Projection
                if "rank" in group:
                    if "projector" not in state:
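
Storing params_idx as state["seed"] gives each parameter a stable integer identity, so a random projection can be regenerated identically on every step (and after a checkpoint reload) instead of being kept in optimizer state. A small sketch of that idea, assuming nothing about Apollo's actual projector:

# Sketch: regenerate the same random projection from a stored per-parameter seed.
import torch

def random_projection(grad: torch.Tensor, rank: int, seed: int) -> torch.Tensor:
    gen = torch.Generator(device=grad.device).manual_seed(seed)
    proj = torch.randn(grad.shape[-1], rank, generator=gen, device=grad.device) / rank ** 0.5
    return grad @ proj  # identical projection every call with the same seed

g = torch.randn(16, 32)
assert torch.equal(random_projection(g, 4, seed=7), random_projection(g, 4, seed=7))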