move the lazily-stored ortho matrix to the grad device for apollo because agony

mrq 2024-12-13 23:22:26 -06:00
parent 09804ecc16
commit 35389481ee


@@ -23,7 +23,6 @@ class GaLoreProjector:
         self.svd_count = 0
     def project(self, full_rank_grad, iter):
         if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
             self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
@@ -170,6 +169,9 @@ class GradientProjector:
         self.seed = seed
     def project(self, full_rank_grad, iter):
+        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
+            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
         if self.proj_type == "std":
             if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
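Note: the added guard mirrors the one already present in GaLoreProjector.project above. The projection matrix is built lazily, so after a checkpoint reload or a device change it can end up on a different device than the incoming gradient; moving it before any matmul avoids a device-mismatch error. A minimal, self-contained sketch of the pattern (hypothetical LazyProjector, not the repo's class):

import torch

class LazyProjector:
    """Hypothetical sketch of the lazy device-alignment pattern added above."""
    def __init__(self):
        self.ortho_matrix = None  # built on first use, so its device can lag behind the grad's

    def project(self, full_rank_grad: torch.Tensor) -> torch.Tensor:
        if self.ortho_matrix is None:
            # placeholder projection; the real projectors build this via SVD or a seeded random matrix
            self.ortho_matrix = torch.eye(full_rank_grad.shape[-1])
        # keep the cached matrix on the gradient's device before any matmul
        if self.ortho_matrix.device != full_rank_grad.device:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
        return full_rank_grad @ self.ortho_matrix.t()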
@@ -302,14 +304,7 @@ class Apollo(Optimizer):
         }
         super().__init__(params, defaults)
-        """ # do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
-        params_idx = 0
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_idx += 1
-                if p.requires_grad:
-                    self.state[p]["seed"] = params_idx
-        """
     @torch.no_grad()
     def step(self, closure: Callable = None):
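Note: the deleted block was already commented out; its own comment warns that writing per-parameter state right after super().__init__ can desync DeepSpeed's optimizer wrapper. A hedged, self-contained sketch of the safer alternative (a hypothetical toy optimizer, not APOLLO itself), where per-parameter state is populated lazily inside step():

import torch

class LazyStateSGD(torch.optim.Optimizer):
    """Hypothetical sketch: optimizer state is populated lazily in step(),
    never in __init__, so wrappers that rebuild state see a clean optimizer."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))
        # deliberately no per-parameter state written here

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        params_idx = 0
        for group in self.param_groups:
            for p in group["params"]:
                params_idx += 1
                if p.grad is None:
                    continue
                state = self.state[p]
                if "seed" not in state:
                    state["seed"] = params_idx  # first-use initialization
                p.add_(p.grad, alpha=-group["lr"])
        return loss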
@@ -329,7 +324,7 @@ class Apollo(Optimizer):
                 params_idx += 1
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError("APOLLO does not support sparse gradients")
@@ -365,9 +360,9 @@ class Apollo(Optimizer):
                 if "exp_avg" not in state:
                     # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(grad).detach()
+                    state["exp_avg"] = torch.zeros_like(grad)
                     # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(grad).detach()
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]
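Note: dropping .detach() here is safe because torch.zeros_like returns a fresh tensor with requires_grad=False and no graph attached, so the extra call was a no-op. A quick sketch to check that assumption:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2                   # y carries a grad_fn (it lives in the autograd graph)
buf = torch.zeros_like(y)   # fresh buffer: requires_grad=False, no grad_fn
assert buf.requires_grad is False and buf.grad_fn is None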
@@ -408,7 +403,7 @@ class Apollo(Optimizer):
                     (torch.norm(grad) + 1e-8)
                 )
-                scaling_grad = p.grad.data * scaling_factor
+                scaling_grad = p.grad * scaling_factor
                 # Use Norm-Growth Limiter in Fira
                 if "scaling_grad" in state:
@@ -425,9 +420,9 @@ class Apollo(Optimizer):
                 norm_grad = scaling_grad * np.sqrt(group["scale"])
-                p.data.add_(norm_grad, alpha=-step_size)
+                p.add_(norm_grad, alpha=-step_size)
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
         return loss
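Note: the remaining hunks drop the legacy .data accessor. Inside a @torch.no_grad() step the parameter can be read and updated in place directly, which keeps autograd's version counters consistent instead of bypassing them through .data. A minimal sketch of the same update style, with assumed values rather than the optimizer's real state:

import torch

p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.randn(4)
lr, weight_decay = 1e-3, 1e-2

with torch.no_grad():
    norm_grad = p.grad                         # read the gradient directly, no .data
    p.add_(norm_grad, alpha=-lr)               # in-place parameter update under no_grad
    if weight_decay > 0.0:
        p.add_(p, alpha=-(lr * weight_decay))  # decoupled weight decay, same in-place style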