This commit is contained in:
mrq 2024-12-14 22:43:51 -06:00
parent 3dd31e74d1
commit 2ba6b483dc

View File

@ -300,9 +300,11 @@ class Apollo(Optimizer):
"proj": proj,
"scale_type": scale_type,
"scale": scale,
"update_proj_gap": update_proj_gap,
"proj_type": proj_type,
}
self.update_proj_gap = update_proj_gap
self.proj_type = proj_type
super().__init__(params, defaults)
# NOTE: do NOT modify the params after this point; doing so appears to desync DeepSpeed's optimizer wrapper
@ -343,18 +345,18 @@ class Apollo(Optimizer):
if group["proj"] == "random":
state["projector"] = GradientProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
update_proj_gap=self.update_proj_gap,
proj_type=self.proj_type,
alpha=group["scale"],
proj_type=group["proj_type"],
seed=state["seed"]
)
elif group["proj"] == "svd":
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
update_proj_gap=self.update_proj_gap,
proj_type=self.proj_type,
scale=group["scale"],
proj_type=group["proj_type"]
)
grad = state["projector"].project(grad, state["step"])