move lazy-stored ortho matrix to the grad device for apollo because agony
parent 09804ecc16 · commit 35389481ee
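The diff below adds the same lazy device check to GradientProjector.project that GaLoreProjector already carries: the orthogonal projection matrix is cached the first time it is built, so if the gradient later shows up on a different device the cached matrix has to be moved over before it is used. A minimal standalone sketch of that pattern, using a hypothetical LazyProjector with a placeholder random projection (not the repository's actual class):

import torch


class LazyProjector:
    # Hypothetical stand-in for this sketch; not the repository's class.
    def __init__(self, rank):
        self.rank = rank
        self.ortho_matrix = None  # built lazily on the first project() call

    def project(self, full_rank_grad):
        # The guard this commit adds: if the cached matrix lives on a
        # different device than the incoming gradient (e.g. CPU vs. CUDA),
        # the matmul below would raise a device-mismatch error, so move it
        # to the gradient's device first.
        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
        if self.ortho_matrix is None:
            # Placeholder initialisation (random low-rank projection), just for the sketch.
            self.ortho_matrix = torch.randn(
                self.rank, full_rank_grad.shape[0],
                device=full_rank_grad.device, dtype=full_rank_grad.dtype,
            )
        return self.ortho_matrix @ full_rank_grad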
@@ -23,7 +23,6 @@ class GaLoreProjector:
         self.svd_count = 0
 
     def project(self, full_rank_grad, iter):
-
         if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
             self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
 
@@ -170,6 +169,9 @@ class GradientProjector:
         self.seed = seed
 
     def project(self, full_rank_grad, iter):
+        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
+            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
+
         if self.proj_type == "std":
             if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
@@ -302,14 +304,7 @@ class Apollo(Optimizer):
         }
         super().__init__(params, defaults)
 
-        """
-        params_idx = 0
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_idx += 1
-                if p.requires_grad:
-                    self.state[p]["seed"] = params_idx
-        """
+        # do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
 
     @torch.no_grad()
     def step(self, closure: Callable = None):
@@ -329,7 +324,7 @@ class Apollo(Optimizer):
                 params_idx += 1
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError("APOLLO does not support sparse gradients")
 
@@ -365,9 +360,9 @@ class Apollo(Optimizer):
 
                 if "exp_avg" not in state:
                     # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(grad).detach()
+                    state["exp_avg"] = torch.zeros_like(grad)
                     # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(grad).detach()
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]
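Side note on the hunk above: torch.zeros_like does not inherit requires_grad (it defaults to False), so the .detach() calls being dropped were redundant; a quick standalone check:

import torch

g = torch.randn(3, requires_grad=True)
state_buf = torch.zeros_like(g)
assert state_buf.requires_grad is False  # no .detach() needed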
@@ -408,7 +403,7 @@ class Apollo(Optimizer):
                     (torch.norm(grad) + 1e-8)
                 )
 
-                scaling_grad = p.grad.data * scaling_factor
+                scaling_grad = p.grad * scaling_factor
 
                 # Use Norm-Growth Limiter in Fira
                 if "scaling_grad" in state:
@@ -425,9 +420,9 @@ class Apollo(Optimizer):
 
                 norm_grad = scaling_grad * np.sqrt(group["scale"])
 
-                p.data.add_(norm_grad, alpha=-step_size)
+                p.add_(norm_grad, alpha=-step_size)
 
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
 
         return loss
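For reference, the weight-decay line touched in the last hunk, p.add_(p, alpha=-lr * wd), is the usual decoupled decay p <- (1 - lr * wd) * p applied in place; a quick standalone check (not part of the commit):

import torch

lr, wd = 1e-3, 0.1
p = torch.randn(4, 4)
expected = p * (1 - lr * wd)

with torch.no_grad():
    p.add_(p, alpha=-lr * wd)  # p <- p + (-lr * wd) * p

torch.testing.assert_close(p, expected)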