From 09804ecc1601c7c40cb6b5cc9db703930bb1c3df Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 13 Dec 2024 23:03:52 -0600
Subject: [PATCH] APOLLO tweaks to make it work with deepspeed

---
 vall_e/config.py           |   2 +-
 vall_e/engines/__init__.py |  30 ++--------
 vall_e/engines/base.py     |  20 ++++---
 vall_e/utils/ext/apollo.py | 107 ++++++++++++++++++-------------------
 4 files changed, 67 insertions(+), 92 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 9a600d5..e0feb4d 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -520,7 +520,7 @@ class DeepSpeed:
 	use_compression_training: bool = False # cope
 	compression_bits: int = 8 # cope
 	inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
-	
+
 	optimizer: bool = True # use DeepSpeed optimizer wrapper
 	amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
 	loss_scale_window: int = 100
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 4caf299..0c0d99d 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -132,41 +132,17 @@ def load_engines(training=True, **model_kwargs):
 			params['d_coef'] = params['lr']
 			params['lr'] = 1.0
 		elif cfg.hyperparameters.optimizer.lower() in ["apollo","apollo-mini"]:
-			"""
-			if backend == "deepspeed":
-				raise Exception("APOLLO currently does not play nicely with DeepSpeed.")
-			"""
-
 			optimizer_class = ml.Apollo
 			is_mini = cfg.hyperparameters.optimizer.lower() == "apollo-mini"
-			param_kwargs = {
+
+			params.update({
 				"rank": 1 if is_mini else 256,
 				"proj": "random",
 				"scale_type": "tensor" if is_mini else "channel",
 				"scale": 128 if is_mini else 1,
 				"update_proj_gap": 200,
 				"proj_type": "std",
-			}
-			# grab any extra configs from the YAML
-			param_kwargs.update(cfg.hyperparameters.optimizer_params)
-			# and blank it so it doesn't update the main optimizer kwargs
-			cfg.hyperparameters.optimizer_params = {}
-
-			"""
-			params["params"] = [{'params': params["params"]} | param_kwargs]
-			"""
-			target_params = []
-			target_modules_list = ["attn", "mlp"]
-			for module_name, module in model.named_modules():
-				if not (isinstance(module, torch.nn.Linear)):
-					continue
-				if not any(target_key in module_name for target_key in target_modules_list):
-					continue
-				target_params.append(module.weight)
-
-			param_ids = [id(p) for p in target_params]
-			regular_params = [p for p in model.parameters() if id(p) not in param_ids]
-			params["params"] = [{'params': regular_params}, {'params': target_params} | param_kwargs]
+			})
 		elif cfg.hyperparameters.optimizer.lower() == "adagrad":
 			optimizer_class = ml.Adagrad
 		else:
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 053eb0f..537a364 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -406,10 +406,14 @@ class Engines(dict[str, Engine]):
 			if cfg.lora is not None:
 				save_dir = cfg.ckpt_dir / cfg.lora.full_name
 
+			engine.save_checkpoint(save_dir, tag=tag)
+
+			"""
 			try:
 				engine.save_checkpoint(save_dir, tag=tag)
 			except Exception as e:
 				_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
+			"""
 
 			# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
 			if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@@ -515,11 +519,11 @@ class Engines(dict[str, Engine]):
 				start_time = time.time()
 
 				batch = to_device(batch, device)
-				n_ooms = torch.zeros([], device=device)
 
 				if not cfg.trainer.check_for_oom:
 					res = feeder( engine=engine, batch=batch, teacher=teacher )
 				else:
+					forward_ooms = torch.zeros([], device=device)
 					try:
 						res = feeder( engine=engine, batch=batch, teacher=teacher )
 					except RuntimeError as e:
@@ -529,12 +533,12 @@ class Engines(dict[str, Engine]):
 							self.save_checkpoint()
 							raise e
 
-						n_ooms += 1
+						forward_ooms += 1
 
 				if world_size() > 1:
-					all_reduce(n_ooms)
+					all_reduce(forward_ooms)
 
-				if n_ooms.item() > 0:
+				if forward_ooms.item() > 0:
 					continue
 					"""
 					self.save_checkpoint()
@@ -554,7 +558,7 @@ class Engines(dict[str, Engine]):
 				if not cfg.trainer.check_for_oom:
 					engine.backward(loss)
 				else:
-					# to-do: properly handle when one GPU throws an OOM because it just halts despite doing a gather/reduce
+					backward_ooms = torch.zeros([], device=device)
 					try:
 						engine.backward(loss)
 					except RuntimeError as e:
@@ -564,12 +568,12 @@ class Engines(dict[str, Engine]):
 							self.save_checkpoint()
 							raise e
 
-						n_ooms += 1
+						backward_ooms += 1
 
 				if world_size() > 1:
-					all_reduce(n_ooms)
+					all_reduce(backward_ooms)
 
-				if n_ooms.item() > 0:
+				if backward_ooms.item() > 0:
 					self.save_checkpoint()
 					raise RuntimeError("Out of memory during backwards pass!")
 
diff --git a/vall_e/utils/ext/apollo.py b/vall_e/utils/ext/apollo.py
index ca4807e..ae08e87 100644
--- a/vall_e/utils/ext/apollo.py
+++ b/vall_e/utils/ext/apollo.py
@@ -28,7 +28,7 @@ class GaLoreProjector:
             self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
 
         if self.proj_type == 'std':
-            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+            if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                     self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                     self.svd_count += 1
@@ -39,7 +39,7 @@ class GaLoreProjector:
                     self.svd_count += 1
                 low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
         elif self.proj_type == 'reverse_std':
-            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+            if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                     self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                     self.svd_count += 1
@@ -70,12 +70,12 @@ class GaLoreProjector:
 
     def project_back(self, low_rank_grad):
         if self.proj_type == 'std':
-            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
+            if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                 full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
             else:
                 full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
         elif self.proj_type == 'reverse_std':
-            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
+            if low_rank_grad.dim() > 1 and low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
                 full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
             else:
                 full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
@@ -170,9 +170,8 @@ class GradientProjector:
         self.seed = seed
 
     def project(self, full_rank_grad, iter):
-
         if self.proj_type == "std":
-            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+            if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                     self.ortho_matrix = self.get_orthogonal_matrix(
                         full_rank_grad, self.rank, type="right", seed=self.seed
@@ -188,7 +187,7 @@ class GradientProjector:
                 low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
 
         elif self.proj_type == "reverse_std":
-            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+            if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                     self.ortho_matrix = self.get_orthogonal_matrix(
                         full_rank_grad, self.rank, type="left", seed=self.seed
@@ -263,27 +262,6 @@ class GradientProjector:
             raise ValueError("type should be left, right or full")
 
 class Apollo(Optimizer):
-    """
-    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
-    Regularization](https://arxiv.org/abs/1711.05101).
-
-    Parameters:
-        params (`Iterable[nn.parameter.Parameter]`):
-            Iterable of parameters to optimize or dictionaries defining parameter groups.
-        lr (`float`, *optional*, defaults to 0.001):
-            The learning rate to use.
-        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
-            Adam's betas parameters (b1, b2).
-        eps (`float`, *optional*, defaults to 1e-06):
-            Adam's epsilon for numerical stability.
-        weight_decay (`float`, *optional*, defaults to 0.0):
-            Decoupled weight decay to apply.
-        correct_bias (`bool`, *optional*, defaults to `True`):
-            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
-        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
-            A flag used to disable the deprecation warning (set to `True` to disable the warning).
-    """
-
     def __init__(
         self,
         params: Iterable[nn.parameter.Parameter],
@@ -292,7 +270,13 @@
         eps: float = 1e-6,
         weight_decay: float = 0.0,
        correct_bias: bool = True,
-        scale_front: bool = False,
+
+        rank: int = 256,
+        proj: str = "random",
+        scale_type: str = "channel",
+        scale: int = 1,
+        update_proj_gap: int = 200,
+        proj_type: str = "std",
     ):
         if lr < 0.0:
             raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
@@ -302,17 +286,30 @@
             raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
-        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "correct_bias": correct_bias,
+
+            "rank": rank,
+            "proj": proj,
+            "scale_type": scale_type,
+            "scale": scale,
+            "update_proj_gap": update_proj_gap,
+            "proj_type": proj_type,
+        }
         super().__init__(params, defaults)
-
-        self.scale_front = scale_front
+        """
         params_idx = 0
         for group in self.param_groups:
             for p in group["params"]:
                 params_idx += 1
                 if p.requires_grad:
                     self.state[p]["seed"] = params_idx
+        """
 
 
     @torch.no_grad()
     def step(self, closure: Callable = None):
@@ -332,42 +329,45 @@
 
                 params_idx += 1
                 if p.grad is None:
                     continue
-                grad = p.grad
+                grad = p.grad.data
                 if grad.is_sparse:
-                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
+                    raise RuntimeError("APOLLO does not support sparse gradients")
 
                 state = self.state[p]
 
                 if "step" not in state:
                     state["step"] = 0
-                
+
                 if "seed" not in state:
                     state["seed"] = params_idx
 
                 # GaLore Projection
-                if "rank" in group:
+                if group["rank"] > 0:
                     if "projector" not in state:
                         if group["proj"] == "random":
-                            state["projector"] = GradientProjector(group["rank"],
+                            state["projector"] = GradientProjector(
+                                group["rank"],
                                 update_proj_gap=group["update_proj_gap"],
                                 alpha=group["scale"],
                                 proj_type=group["proj_type"],
-                                seed=state["seed"])
+                                seed=state["seed"]
+                            )
                         elif group["proj"] == "svd":
-                            state["projector"] = GaLoreProjector(group["rank"],
+                            state["projector"] = GaLoreProjector(
+                                group["rank"],
                                 update_proj_gap=group["update_proj_gap"],
                                 scale=group["scale"],
-                                proj_type=group["proj_type"])
+                                proj_type=group["proj_type"]
+                            )
 
                     grad = state["projector"].project(grad, state["step"])
-
                 # State initialization
                 if "exp_avg" not in state:
                     # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(grad)
+                    state["exp_avg"] = torch.zeros_like(grad).detach()
                     # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(grad)
+                    state["exp_avg_sq"] = torch.zeros_like(grad).detach()
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]
@@ -389,9 +389,12 @@
 
                 # compute norm gradient
                 norm_grad = exp_avg / denom
-                if "rank" in group:
+                if group["rank"] > 0:
                     if group['scale_type'] == 'channel':
-                        norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
+                        if norm_grad.dim() > 1:
+                            norm_dim = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
+                        else:
+                            norm_dim = 0
                         scaling_factor = (
                             torch.norm(norm_grad, dim=norm_dim) /
                             (torch.norm(grad, dim=norm_dim) + 1e-8)
@@ -405,7 +408,7 @@
                             (torch.norm(grad) + 1e-8)
                         )
 
-                    scaling_grad = p.grad * scaling_factor
+                    scaling_grad = p.grad.data * scaling_factor
 
                     # Use Norm-Growth Limiter in Fira
                     if "scaling_grad" in state:
@@ -422,17 +425,9 @@
 
                     norm_grad = scaling_grad * np.sqrt(group["scale"])
 
-                p.add_(norm_grad, alpha=-step_size)
+                p.data.add_(norm_grad, alpha=-step_size)
 
-                # Just adding the square of the weights to the loss function is *not*
-                # the correct way of using L2 regularization/weight decay with Adam,
-                # since that will interact with the m and v parameters in strange ways.
-                #
-                # Instead we want to decay the weights in a manner that doesn't interact
-                # with the m/v parameters. This is equivalent to adding the square
-                # of the weights to the loss with plain (non-momentum) SGD.
-                # Add weight decay at the end (fixed version)
                 if group["weight_decay"] > 0.0:
-                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                    p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
 
         return loss
\ No newline at end of file
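Note (not part of the patch itself): the sketch below is a minimal, illustrative smoke test of the patched optimizer outside of DeepSpeed. It assumes the patched vall_e/utils/ext/apollo.py is importable from a checkout of the repo; the toy model, learning rate, and batch here are made up for the example, and the keyword arguments simply mirror the "apollo-mini" preset that load_engines() now passes straight through params.update({...}).

# Illustrative sketch only -- assumes `vall_e` is on PYTHONPATH.
import torch
from vall_e.utils.ext.apollo import Apollo

# stand-in model; bias-free Linear layers keep every parameter 2-D
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64, bias=False),
)

optimizer = Apollo(
    model.parameters(),
    lr=1e-4,
    # the "apollo-mini" preset from load_engines():
    rank=1,                  # rank-1 low-rank projection
    proj="random",           # random projection (GradientProjector) rather than SVD
    scale_type="tensor",     # per-tensor scaling for the mini variant
    scale=128,
    update_proj_gap=200,     # refresh the projection every 200 steps
    proj_type="std",
)

x = torch.randn(8, 64)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()             # projector and EMA state are created lazily on the first step
optimizer.zero_grad()

Because the projection settings now live in the param-group defaults instead of a separate hand-built parameter group, the same keywords work whether the optimizer is constructed directly (as above) or handed to DeepSpeed's optimizer wrapper by the trainer.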