move lazy-stored ortho matrix to the grad device for apollo because agony

This commit is contained in:
mrq 2024-12-13 23:22:26 -06:00
parent 09804ecc16
commit 35389481ee
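For context, the core of the change is a lazy device check in the projector: the cached orthogonal matrix is only created on the first projection, so when gradients later arrive on a different device (for example after optimizer state is restored or sharded), the matrix has to follow them before any matmul. Below is a minimal, self-contained sketch of that pattern; the class name and the placeholder identity projection are illustrative only, not the repo's actual projector code.

import torch

class LazyProjectorSketch:
    """Illustrative stand-in for the projectors touched by this commit."""

    def __init__(self):
        self.ortho_matrix = None  # created lazily on the first projection

    def project(self, full_rank_grad: torch.Tensor) -> torch.Tensor:
        # The fix: keep the cached matrix on the gradient's device, otherwise the
        # matmul below raises a cross-device error once grads move (e.g. to CUDA).
        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
        if self.ortho_matrix is None:
            # placeholder projection; the real code builds this from an SVD or a
            # seeded random projection refreshed every update_proj_gap steps
            self.ortho_matrix = torch.eye(
                full_rank_grad.shape[-1],
                device=full_rank_grad.device,
                dtype=full_rank_grad.dtype,
            )
        return full_rank_grad @ self.ortho_matrix

proj = LazyProjectorSketch()
g_cpu = torch.randn(8, 4)
low = proj.project(g_cpu)                 # matrix created on CPU
if torch.cuda.is_available():
    low = proj.project(g_cpu.cuda())      # matrix follows the grad to CUDA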


@@ -23,7 +23,6 @@ class GaLoreProjector:
        self.svd_count = 0

    def project(self, full_rank_grad, iter):
        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
@@ -170,6 +169,9 @@ class GradientProjector:
        self.seed = seed

    def project(self, full_rank_grad, iter):
+        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
+            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)

        if self.proj_type == "std":
            if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
@@ -302,14 +304,7 @@ class Apollo(Optimizer):
        }
        super().__init__(params, defaults)

-        """
-        params_idx = 0
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_idx += 1
-                if p.requires_grad:
-                    self.state[p]["seed"] = params_idx
-        """
+        # do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe

    @torch.no_grad()
    def step(self, closure: Callable = None):
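The removed block above stashed a per-parameter seed into self.state during __init__; the new comment explains why that is avoided: wrappers such as DeepSpeed rebuild or shard the optimizer state after construction, so state written in __init__ can fall out of sync. The usual workaround, sketched below under that assumption, is to create per-parameter state lazily inside step(), the same way the exp_avg buffers further down are created; the function name here is hypothetical.

import torch

def lazy_state_init_sketch(optimizer: torch.optim.Optimizer) -> None:
    # Hypothetical illustration: per-parameter bookkeeping is created the first
    # time step() sees the parameter, never in __init__, so external wrappers
    # that rebuild optimizer state (e.g. DeepSpeed) stay consistent.
    params_idx = 0
    for group in optimizer.param_groups:
        for p in group["params"]:
            params_idx += 1
            if p.grad is None:
                continue
            state = optimizer.state[p]
            if "seed" not in state:
                state["seed"] = params_idx  # recorded lazily, mirrors the loop order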
@@ -329,7 +324,7 @@ class Apollo(Optimizer):
                params_idx += 1

                if p.grad is None:
                    continue

-                grad = p.grad.data
+                grad = p.grad

                if grad.is_sparse:
                    raise RuntimeError("APOLLO does not support sparse gradients")
@@ -365,9 +360,9 @@ class Apollo(Optimizer):
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad).detach()
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad).detach()
state["exp_avg_sq"] = torch.zeros_like(grad)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
@@ -408,7 +403,7 @@ class Apollo(Optimizer):
                    (torch.norm(grad) + 1e-8)
                )

-                scaling_grad = p.grad.data * scaling_factor
+                scaling_grad = p.grad * scaling_factor

                # Use Norm-Growth Limiter in Fira
                if "scaling_grad" in state:
@@ -425,9 +420,9 @@ class Apollo(Optimizer):
norm_grad = scaling_grad * np.sqrt(group["scale"])
p.data.add_(norm_grad, alpha=-step_size)
p.add_(norm_grad, alpha=-step_size)
if group["weight_decay"] > 0.0:
p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
return loss
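The remaining hunks drop the .data and .detach() indirection: step() already runs under @torch.no_grad(), so writing through the parameter directly is equivalent and is the current PyTorch idiom. A minimal sketch of the resulting update is below; the function name and the hyperparameter values are arbitrary illustrations, not the repo's defaults.

import torch

@torch.no_grad()
def apply_update_sketch(p: torch.Tensor, norm_grad: torch.Tensor,
                        step_size: float, lr: float, weight_decay: float) -> None:
    # In-place ops on the parameter are safe here because of the no_grad context,
    # so p.add_(...) replaces the older p.data.add_(...) spelling.
    p.add_(norm_grad, alpha=-step_size)
    if weight_decay > 0.0:
        # decoupled weight decay, applied straight to the parameter as in the diff
        p.add_(p, alpha=-(lr * weight_decay))

param = torch.nn.Parameter(torch.randn(4, 4))
param.grad = torch.randn(4, 4)
apply_update_sketch(param, param.grad, step_size=1e-3, lr=1e-3, weight_decay=0.01)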