move lazy-stored ortho matrix to the grad device for apollo because agony
This commit is contained in:
parent 09804ecc16
commit 35389481ee
@@ -23,7 +23,6 @@ class GaLoreProjector:
         self.svd_count = 0
-
 
     def project(self, full_rank_grad, iter):
         if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
             self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
 
@@ -170,6 +169,9 @@ class GradientProjector:
         self.seed = seed
 
     def project(self, full_rank_grad, iter):
+        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
+            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
+
         if self.proj_type == "std":
             if full_rank_grad.dim() > 1 and full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                 if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
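For context, here is a minimal sketch of the lazy-init-plus-device-guard pattern this hunk adds to `GradientProjector` (the class name `LowRankProjector` and the QR-based init below are illustrative stand-ins, not the code in this file). The ortho matrix is materialized on first use, so a checkpoint restored on CPU or a `model.to(...)` after the first step can leave the cached matrix on a different device than the incoming gradient; the guard moves it before any matmul:

```python
import torch

class LowRankProjector:
    """Illustrative stand-in for GaLoreProjector / GradientProjector."""

    def __init__(self, rank: int):
        self.rank = rank
        self.ortho_matrix = None  # created lazily on the first project() call

    def project(self, full_rank_grad: torch.Tensor) -> torch.Tensor:
        # Guard: a checkpoint reload or model.to(...) after step 0 can leave the
        # cached matrix on a different device than the incoming gradient.
        if self.ortho_matrix is not None and self.ortho_matrix.device != full_rank_grad.device:
            self.ortho_matrix = self.ortho_matrix.to(full_rank_grad.device)
        if self.ortho_matrix is None:
            # Lazy init: an orthonormal basis on the gradient's device (illustrative).
            q, _ = torch.linalg.qr(
                torch.randn(full_rank_grad.shape[0], self.rank, device=full_rank_grad.device)
            )
            self.ortho_matrix = q
        return self.ortho_matrix.T @ full_rank_grad  # project into the rank-r space
```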
@@ -302,14 +304,7 @@ class Apollo(Optimizer):
         }
         super().__init__(params, defaults)
 
-        """
-        params_idx = 0
-        for group in self.param_groups:
-            for p in group["params"]:
-                params_idx += 1
-                if p.requires_grad:
-                    self.state[p]["seed"] = params_idx
-        """
+        # do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
 
     @torch.no_grad()
     def step(self, closure: Callable = None):
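The deleted block pre-populated `self.state[p]["seed"]` in `__init__`. PyTorch optimizers conventionally leave `self.state` empty until `step()`, and wrappers such as DeepSpeed's may rebuild or re-partition the optimizer after construction, so state created eagerly can fall out of sync, which is what the new comment warns about. A hedged sketch of the lazy alternative, assuming the `params_idx` counter that `step()` already maintains (see the next hunk):

```python
# Sketch only: create per-parameter state on first use inside step(),
# so it survives whatever wrapping happens after __init__.
for group in self.param_groups:
    for p in group["params"]:
        params_idx += 1
        if p.grad is None:
            continue
        state = self.state[p]
        if "seed" not in state:           # lazy, wrapper-safe initialization
            state["seed"] = params_idx    # mirrors the removed __init__ loop
```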
@@ -329,7 +324,7 @@ class Apollo(Optimizer):
                 params_idx += 1
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError("APOLLO does not support sparse gradients")
 
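Dropping the legacy `.data` accessor here is the recommended modernization: `.data` returns a detached alias that bypasses autograd's version-counter checks and can mask in-place errors, while `step()` already runs under `@torch.no_grad()`, so reading `p.grad` directly yields the same values. A minimal check:

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
p.sum().backward()

with torch.no_grad():
    g_new = p.grad       # preferred: keeps autograd bookkeeping intact
    g_old = p.grad.data  # legacy: detached alias, version checks bypassed
    assert torch.equal(g_new, g_old)  # identical values under no_grad
```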
@@ -365,9 +360,9 @@ class Apollo(Optimizer):
 
                 if "exp_avg" not in state:
                     # Exponential moving average of gradient values
-                    state["exp_avg"] = torch.zeros_like(grad).detach()
+                    state["exp_avg"] = torch.zeros_like(grad)
                     # Exponential moving average of squared gradient values
-                    state["exp_avg_sq"] = torch.zeros_like(grad).detach()
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
 
                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                 beta1, beta2 = group["betas"]
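Removing `.detach()` is behavior-preserving: `torch.zeros_like` is a factory function, so its result is already a leaf tensor with `requires_grad=False` regardless of the input's autograd status, and detaching it again was a no-op:

```python
import torch

x = torch.randn(4, requires_grad=True)
buf = torch.zeros_like(x)
assert buf.is_leaf and not buf.requires_grad  # .detach() added nothing
```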
@@ -408,7 +403,7 @@ class Apollo(Optimizer):
                         (torch.norm(grad) + 1e-8)
                     )
 
-                    scaling_grad = p.grad.data * scaling_factor
+                    scaling_grad = p.grad * scaling_factor
 
                     # Use Norm-Growth Limiter in Fira
                     if "scaling_grad" in state:
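For reference, the Norm-Growth Limiter named in the context comes from Fira: it caps how much the scaled gradient's norm may grow from one step to the next, damping spikes after a projector update. A rough sketch of that idea, with a hypothetical `max_ratio` knob, not the exact code in this file:

```python
import torch

def norm_growth_limit(scaling_grad: torch.Tensor, state: dict,
                      max_ratio: float = 1.01) -> torch.Tensor:
    """Cap step-to-step growth of ||scaling_grad|| at max_ratio (Fira-style sketch)."""
    current_norm = torch.norm(scaling_grad)
    prev_norm = state.get("scaling_grad_norm")
    if prev_norm is not None and current_norm > max_ratio * prev_norm:
        scale = max_ratio * prev_norm / (current_norm + 1e-8)
        scaling_grad = scaling_grad * scale   # shrink back to the allowed growth
        current_norm = current_norm * scale
    state["scaling_grad_norm"] = current_norm
    return scaling_grad
```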
@@ -425,9 +420,9 @@ class Apollo(Optimizer):
 
                 norm_grad = scaling_grad * np.sqrt(group["scale"])
 
-                p.data.add_(norm_grad, alpha=-step_size)
+                p.add_(norm_grad, alpha=-step_size)
 
                 if group["weight_decay"] > 0.0:
-                    p.data.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
 
         return loss
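Same rationale as the `grad = p.grad` change above: under `@torch.no_grad()` the parameter can be updated in place directly, so the `.data` indirection buys nothing and only hides the write from autograd's sanity checks. Note that the decay term `p.add_(p, alpha=-lr * wd)` is equivalent to shrinking the weights by `(1 - lr * wd)` after the main step. A compact sketch of the combined update (function name hypothetical):

```python
import torch

@torch.no_grad()
def apply_update(p: torch.Tensor, norm_grad: torch.Tensor,
                 step_size: float, lr: float, weight_decay: float) -> None:
    p.add_(norm_grad, alpha=-step_size)      # main APOLLO step
    if weight_decay > 0.0:
        p.add_(p, alpha=-lr * weight_decay)  # p *= (1 - lr * weight_decay)
```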