diff --git a/vall_e/utils/ext/apollo.py b/vall_e/utils/ext/apollo.py
index ae08e87..e029bca 100644
--- a/vall_e/utils/ext/apollo.py
+++ b/vall_e/utils/ext/apollo.py
@@ -23,7 +23,6 @@ class GaLoreProjector:
         self.svd_count = 0
 
     def project(self, full_rank_grad, iter):
-
         if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
             self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
 
@@ -170,6 +169,9 @@ class GradientProjector:
         self.seed = seed
 
     def project(self, full_rank_grad, iter):
+        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
+            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
+
         if self.proj_type == "std":
             if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
@@ -302,14 +304,7 @@ class Apollo(Optimizer):
         }
         super().__init__(params, defaults)
 
-        """
-        params_idx = 0
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_idx += 1
-                if p.requires_grad:
-                    self.state[p]["seed"] = params_idx
-        """
+        # do NOT touch the params after super().__init__(); it desyncs DeepSpeed's optimizer wrapper
 
     @torch.no_grad()
     def step(self, closure: Callable = None):
@@ -329,7 +324,7 @@ class Apollo(Optimizer):
                 params_idx += 1
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
 
                 if grad.is_sparse:
                     raise RuntimeError("APOLLO does not support sparse gradients")
@@ -365,9 +360,9 @@ class Apollo(Optimizer):
 
                 if "exp_avg" not in state:
                     # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(grad).detach()
+                    state["exp_avg"] = torch.zeros_like(grad)
                     # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(grad).detach()
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]
@@ -408,7 +403,7 @@ class Apollo(Optimizer):
                         (torch.norm(grad) + 1e-8)
                     )
 
-                scaling_grad = p.grad.data * scaling_factor
+                scaling_grad = p.grad * scaling_factor
 
                 # Use Norm-Growth Limiter in Fira
                 if "scaling_grad" in state:
@@ -425,9 +420,9 @@ class Apollo(Optimizer):
 
                 norm_grad = scaling_grad * np.sqrt(group["scale"])
 
-                p.data.add_(norm_grad, alpha=-step_size)
+                p.add_(norm_grad, alpha=-step_size)
 
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
 
         return loss
\ No newline at end of file
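
Note on the device-sync hunks: the guard added to `GradientProjector.project` mirrors the one `GaLoreProjector` already has, covering the case where the cached `ortho_matrix` was built on one device while the incoming gradient lives on another (e.g. after a model `.to()` move or a checkpoint resume). A minimal standalone sketch of the pattern; `CachedProjector` is hypothetical, not the repo's class:

```python
import torch

class CachedProjector:
    """Hypothetical stand-in for the repo's projector classes."""

    def __init__(self, rank: int = 8):
        self.rank = rank
        self.ortho_matrix = None  # built lazily, then cached across steps

    def project(self, full_rank_grad: torch.Tensor) -> torch.Tensor:
        # The cache can be left on a stale device, e.g. built on CPU while
        # gradients now live on CUDA; migrate it lazily before use.
        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
        if self.ortho_matrix is None:
            # Placeholder for the real SVD / seeded-random construction.
            self.ortho_matrix = torch.randn(
                self.rank, full_rank_grad.shape[0], device=full_rank_grad.device
            )
        return self.ortho_matrix @ full_rank_grad
```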
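Note on the `.data` removals: `step()` already runs under `@torch.no_grad()`, so `grad = p.grad` and `p.add_(...)` perform the same in-place updates as the old `.data` forms while avoiding the legacy `.data` alias. The dropped `.detach()` calls were likewise no-ops, since `torch.zeros_like` returns a tensor with `requires_grad=False`. A quick self-contained check with throwaway tensors, not the optimizer itself:

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 0.5)

with torch.no_grad():
    # Same in-place update as the old p.data.add_(p.grad.data, alpha=-0.1).
    p.add_(p.grad, alpha=-0.1)

print(p)  # tensor([0.9500, 0.9500, 0.9500], requires_grad=True)

# torch.zeros_like already returns a tensor with requires_grad=False,
# so the removed .detach() calls changed nothing.
assert not torch.zeros_like(p.grad).requires_grad
```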
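Note on the deleted `__init__` block: writing to `self.state[p]` before the first `step()` is exactly what the new comment warns against, since wrappers such as DeepSpeed's may snapshot or repartition the optimizer after construction and expect per-parameter state to be created lazily. If a per-parameter seed were still wanted, one option (an assumption, not something this patch implements) is to derive it from the traversal index inside `step()`, which the existing `params_idx` counter already tracks. A reduced, hypothetical sketch:

```python
import torch

class ApolloSketch(torch.optim.Optimizer):
    """Hypothetical reduction of Apollo: lazy per-parameter seeding in step()."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))
        # Deliberately no self.state writes here (see the removed block above).

    @torch.no_grad()
    def step(self, closure=None):
        params_idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                params_idx += 1
                if p.grad is None:
                    continue
                state = self.state[p]
                # First touch happens inside step(), where optimizer
                # wrappers expect lazy state creation.
                if "seed" not in state:
                    state["seed"] = params_idx
                p.add_(p.grad, alpha=-group["lr"])
```

Because `params_idx` counts every parameter (including those skipped for lacking a gradient), the lazily assigned seed matches what the removed eager loop would have produced.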