From 9a62e3b824d2637b8c6a150507f158230b39b178 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 12 Dec 2024 00:31:58 -0600
Subject: [PATCH] APOLLO cringe (doesn't want to work with deepspeed)

---
 vall_e/data.py             |   5 +-
 vall_e/engines/__init__.py | 143 +++++++++++++++++++++----------------
 vall_e/models/ar_nar.py    |  19 ++++-
 vall_e/utils/ext/apollo.py |   5 ++
 4 files changed, 108 insertions(+), 64 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 0863b76..e88695c 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -621,7 +621,10 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 
 	if cfg.dataset.use_metadata and metadata_path.exists():
 		#metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
-		metadata = json_read( metadata_path )
+		try:
+			metadata = json_read( metadata_path )
+		except Exception as e:
+			return []
 
 	if len(metadata) == 0:
 		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index ae9bcc0..f404254 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -132,6 +132,9 @@ def load_engines(training=True, **model_kwargs):
 			params['d_coef'] = params['lr']
 			params['lr'] = 1.0
 		elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
+			if backend == "deepspeed":
+				raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
+
 			optimizer_class = ml.Apollo
 			is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
 			param_kwargs = {
@@ -146,8 +149,22 @@ def load_engines(training=True, **model_kwargs):
 				param_kwargs.update(cfg.hyperparameters.optimizer_params)
 				# and blank it so it doesn't update the main optimizer kwargs
 				cfg.hyperparameters.optimizer_params = {}
-			# settings are stored under params
-			params["params"] = [dict(params=params["params"], **param_kwargs)]
+
+			"""
+			params["params"] = [{'params': params["params"]} | param_kwargs]
+			"""
+			target_params = []
+			target_modules_list = ["attn", "mlp"]
+			for module_name, module in model.named_modules():
+				if not (isinstance(module, torch.nn.Linear)):
+					continue
+				if not any(target_key in module_name for target_key in target_modules_list):
+					continue
+				target_params.append(module.weight)
+
+			param_ids = [id(p) for p in target_params]
+			regular_params = [p for p in model.parameters() if id(p) not in param_ids]
+			params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 			optimizer_class = ml.Adagrad
 		else:
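
The hunk above stops handing APOLLO a single catch-all param group and instead splits the model in two: only nn.Linear weights sitting under attention/MLP submodules receive the low-rank projection kwargs, while embeddings, norms, and classifiers fall into a plain group. A minimal standalone sketch of that split, assuming an ordinary torch.nn.Module and borrowing the APOLLO-Mini-style settings that appear later in this patch (the helper name is illustrative, not part of the patch):

import torch

def build_apollo_param_groups( model: torch.nn.Module, param_kwargs: dict ) -> list:
	# only attention/MLP Linear weights receive the projection settings
	target_modules_list = ["attn", "mlp"]
	target_params = []
	for module_name, module in model.named_modules():
		if not isinstance(module, torch.nn.Linear):
			continue
		if not any(key in module_name for key in target_modules_list):
			continue
		target_params.append(module.weight)

	# compare by id() so tensor __eq__ never gets invoked
	param_ids = { id(p) for p in target_params }
	regular_params = [ p for p in model.parameters() if id(p) not in param_ids ]
	return [ {'params': regular_params}, {'params': target_params} | param_kwargs ]

# e.g. (values mirror the apollo-mini branch of this patch; lr is a placeholder):
#   groups = build_apollo_param_groups(model, {'rank': 1, 'proj': 'random',
#   	'scale_type': 'tensor', 'scale': 128, 'update_proj_gap': 200, 'proj_type': 'std'})
#   optimizer = ml.Apollo(groups, lr=1.0e-4)
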
@@ -207,6 +224,28 @@ def load_engines(training=True, **model_kwargs):
 				for k in erase:
 					del state[k]
 
+			# converts an AR+NAR model into an AR+NAR-len model
+			"""
+			if True:
+				# move STT one over
+				state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
+				state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
+				# copy from AR:0:0 classifier
+				if True:
+					state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
+					state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
+					# copy from AR:0:0 embeddings
+					state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
+				# remove
+				else:
+					if 'classifiers.proj.8.weight' in state:
+						del state['classifiers.proj.8.weight']
+					if 'classifiers.proj.8.bias' in state:
+						del state['classifiers.proj.8.bias']
+					if 'resps_emb.embeddings.8.weight' in state:
+						del state['resps_emb.embeddings.8.weight']
+			"""
+
 			# resize modules if I'm doing experiments and can't be assed to manually trim things
 			if cfg.trainer.resize_modules:
 				uses_stop_token = 1 if ("ar" in model.capabilities or "len" in model.capabilities) > 0 else 0
@@ -235,69 +274,49 @@ def load_engines(training=True, **model_kwargs):
 					continue
 				state[k] = ml.resize_weight( state[k], tokens )
 
-			"""
-			if True:
-				# move STT one over
-				state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
-				state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
-				# copy from AR:0:0 classifier
-				if False:
-					state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
-					state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
-					# copy from AR:0:0 embeddings
-					state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
-				# remove
-				else:
-					if 'classifiers.proj.8.weight' in state:
-						del state['classifiers.proj.8.weight']
-					if 'classifiers.proj.8.bias' in state:
-						del state['classifiers.proj.8.bias']
-					if 'resps_emb.embeddings.8.weight' in state:
-						del state['resps_emb.embeddings.8.weight']
-			"""
+			# stuff to inject new layers into an existing model to train over (not recommended, it doesn't amount to anything)
+			"""
+			if True:
+				remapped_dict = {}
+				remapped_indices = [
+					(0, 1),
+					(1, 2),
+					(2, 3),
+					(3, 5),
+					(4, 6),
+					(5, 7),
+					(6, 9),
+					(7, 10),
+					(8, 11),
+					(9, 13),
+					(10, 14),
+					(11, 15),
+				]
-			"""
-			if True:
-				remapped_dict = {}
-				remapped_indices = [
-					(0, 1),
-					(1, 2),
-					(2, 3),
-					(3, 5),
-					(4, 6),
-					(5, 7),
-					(6, 9),
-					(7, 10),
-					(8, 11),
-					(9, 13),
-					(10, 14),
-					(11, 15),
-				]
+				for src, dst in remapped_indices:
+					remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
+					remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
+					remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
+					remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
+					remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
+					remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
+					remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
+					remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
+					remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
-			for src, dst in remapped_indices:
-				remapped_dict[f"model.layers.{dst}.input_layernorm.weight"] = state[f"model.layers.{src}.input_layernorm.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.k_proj.weight"] = state[f"model.layers.{src}.self_attn.k_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.q_proj.weight"] = state[f"model.layers.{src}.self_attn.q_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.v_proj.weight"] = state[f"model.layers.{src}.self_attn.v_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.self_attn.o_proj.weight"] = state[f"model.layers.{src}.self_attn.o_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.post_attention_layernorm.weight"] = state[f"model.layers.{src}.post_attention_layernorm.weight"]
-				remapped_dict[f"model.layers.{dst}.mlp.down_proj.weight"] = state[f"model.layers.{src}.mlp.down_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.mlp.gate_proj.weight"] = state[f"model.layers.{src}.mlp.gate_proj.weight"]
-				remapped_dict[f"model.layers.{dst}.mlp.up_proj.weight"] = state[f"model.layers.{src}.mlp.up_proj.weight"]
+					del state[f"model.layers.{src}.input_layernorm.weight"]
+					del state[f"model.layers.{src}.self_attn.k_proj.weight"]
+					del state[f"model.layers.{src}.self_attn.q_proj.weight"]
+					del state[f"model.layers.{src}.self_attn.v_proj.weight"]
+					del state[f"model.layers.{src}.self_attn.o_proj.weight"]
+					del state[f"model.layers.{src}.post_attention_layernorm.weight"]
+					del state[f"model.layers.{src}.mlp.down_proj.weight"]
+					del state[f"model.layers.{src}.mlp.gate_proj.weight"]
+					del state[f"model.layers.{src}.mlp.up_proj.weight"]
-				del state[f"model.layers.{src}.input_layernorm.weight"]
-				del state[f"model.layers.{src}.self_attn.k_proj.weight"]
-				del state[f"model.layers.{src}.self_attn.q_proj.weight"]
-				del state[f"model.layers.{src}.self_attn.v_proj.weight"]
-				del state[f"model.layers.{src}.self_attn.o_proj.weight"]
-				del state[f"model.layers.{src}.post_attention_layernorm.weight"]
-				del state[f"model.layers.{src}.mlp.down_proj.weight"]
-				del state[f"model.layers.{src}.mlp.gate_proj.weight"]
-				del state[f"model.layers.{src}.mlp.up_proj.weight"]
-
-			for k, v in remapped_dict.items():
-				state[k] = v
-			"""
+				for k, v in remapped_dict.items():
+					state[k] = v
+			"""
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
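
Both commented-out blocks above are one-off checkpoint surgery: the first shuffles classifier heads to turn an AR+NAR checkpoint into an AR+NAR-len one, the second renumbers transformer layers so a deeper model can be seeded from a shallower checkpoint, leaving the gaps to initialize fresh. A hedged sketch of the general renumbering trick; the helper name and prefix-based key handling are illustrative, not part of the patch:

def remap_layer_indices( state, remapped_indices, prefix="model.layers." ):
	# pop every "{prefix}{src}." key first so overlapping src/dst ranges
	# always read the original weights, never an already-moved copy
	remapped = {}
	for src, dst in remapped_indices:
		src_prefix = f"{prefix}{src}."
		dst_prefix = f"{prefix}{dst}."
		for key in list(state.keys()):
			if key.startswith(src_prefix):
				remapped[dst_prefix + key[len(src_prefix):]] = state.pop(key)
	state.update(remapped)
	return state

# stretch a 12-layer checkpoint across 16 layers (slots 0, 4, 8, 12 stay fresh),
# then load non-strictly so the missing keys keep their random init:
#   state = remap_layer_indices(state, [(0, 1), (1, 2), (2, 3), (3, 5), (4, 6), (5, 7),
#   	(6, 9), (7, 10), (8, 11), (9, 13), (10, 14), (11, 15)])
#   model.load_state_dict(state, strict=False)
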
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 65b5cd1..a4a3a68 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -934,6 +934,7 @@ def example_usage():
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
 	learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
 
+	params = model.parameters()
 	if cfg.optimizations.dadaptation:
 		# do not combine the two
 		if scheduler == "schedulefree":
@@ -966,12 +967,28 @@ def example_usage():
 			learning_rate = 0.01
 
 		optimizer = ml.Apollo
+
+		"""
+		target_params = []
+		target_modules_list = ["attn", "mlp"]
+		for module_name, module in model.named_modules():
+			if not (isinstance(module, torch.nn.Linear)):
+				continue
+			if not any(target_key in module_name for target_key in target_modules_list):
+				continue
+			target_params.append(module.weight)
+
+		param_ids = [id(p) for p in target_params]
+		regular_params = [p for p in model.parameters() if id(p) not in param_ids]
+		params = [{'params': regular_params}, {'params': target_params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128, 'update_proj_gap': 200, 'proj_type': 'std'}]
+		"""
+		params = [{'params': params, 'rank': 1, 'proj': 'random', 'scale_type': 'tensor', 'scale': 128, 'update_proj_gap': 200, 'proj_type': 'std'}]
 	else:
 		raise ValueError(f"Unrecognized optimizer: {optimizer}")
 
 	_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
 
-	optimizer = optimizer(model.parameters(), lr=learning_rate)
+	optimizer = optimizer(params, lr=learning_rate)
 
 	if scheduler == "schedulefree":
 		if isinstance(optimizer, ml.AdamW):
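
In example_usage(), the APOLLO hyperparameters now ride along inside the param group rather than the constructor, so the same optimizer(params, lr=...) call works for every branch. Roughly, the live (uncommented) path boils down to the sketch below; the wrapper function is mine, and the values are the rank-1/random-projection/tensor-scale recipe visible in the hunk, which matches the APOLLO-Mini configuration:

from vall_e.utils.ext.apollo import Apollo

def make_apollo( model, learning_rate=0.01 ):
	params = [{
		'params': model.parameters(),
		'rank': 1,                # rank of the low-rank gradient projection
		'proj': 'random',         # random projection, refreshed rather than recomputed via SVD
		'scale_type': 'tensor',   # one scale per tensor (the "mini" flavor)
		'scale': 128,
		'update_proj_gap': 200,   # optimizer steps between projector refreshes
		'proj_type': 'std',
	}]
	return Apollo( params, lr=learning_rate )
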
diff --git a/vall_e/utils/ext/apollo.py b/vall_e/utils/ext/apollo.py
index 51b4983..ca4807e 100644
--- a/vall_e/utils/ext/apollo.py
+++ b/vall_e/utils/ext/apollo.py
@@ -326,8 +326,10 @@ class Apollo(Optimizer):
 		if closure is not None:
 			loss = closure()
 
+		params_idx = 0
 		for group in self.param_groups:
 			for p in group["params"]:
+				params_idx += 1
 				if p.grad is None:
 					continue
 				grad = p.grad
@@ -339,6 +341,9 @@ class Apollo(Optimizer):
 				if "step" not in state:
 					state["step"] = 0
 
+				if "seed" not in state:
+					state["seed"] = params_idx
+
 				# GaLore Projection
 				if "rank" in group:
 					if "projector" not in state:
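
The step() change above walks the param groups with a running params_idx and stamps each parameter's state with a stable "seed" the first time it is seen, presumably so each parameter's random projector can be re-drawn deterministically instead of depending on global RNG state. A hedged sketch of that pattern in isolation; the function and the seed-mixing scheme are illustrative, not Apollo's actual projector code:

import torch

def deterministic_low_rank_grad( grad, rank, seed, step, update_proj_gap ):
	# the same (seed, step // update_proj_gap) pair always yields the same
	# projection, so a resumed run re-draws exactly the projector it left off with
	generator = torch.Generator( device=grad.device )
	generator.manual_seed( seed * 1000003 + step // update_proj_gap )
	proj = torch.randn( grad.shape[-1], rank, generator=generator, device=grad.device, dtype=grad.dtype )
	return grad @ proj
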