This commit is contained in:
mrq 2024-12-14 22:43:51 -06:00
parent 3dd31e74d1
commit 2ba6b483dc

View File

@@ -300,9 +300,11 @@ class Apollo(Optimizer):
"proj": proj, "proj": proj,
"scale_type": scale_type, "scale_type": scale_type,
"scale": scale, "scale": scale,
"update_proj_gap": update_proj_gap,
"proj_type": proj_type,
} }
self.update_proj_gap = update_proj_gap
self.proj_type = proj_type
super().__init__(params, defaults) super().__init__(params, defaults)
# do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe # do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
@@ -343,18 +345,18 @@ class Apollo(Optimizer):
if group["proj"] == "random": if group["proj"] == "random":
state["projector"] = GradientProjector( state["projector"] = GradientProjector(
group["rank"], group["rank"],
update_proj_gap=group["update_proj_gap"], update_proj_gap=self.update_proj_gap,
proj_type=self.proj_type,
alpha=group["scale"], alpha=group["scale"],
proj_type=group["proj_type"],
seed=state["seed"] seed=state["seed"]
) )
elif group["proj"] == "svd": elif group["proj"] == "svd":
state["projector"] = GaLoreProjector( state["projector"] = GaLoreProjector(
group["rank"], group["rank"],
update_proj_gap=group["update_proj_gap"], update_proj_gap=self.update_proj_gap,
proj_type=self.proj_type,
scale=group["scale"], scale=group["scale"],
proj_type=group["proj_type"]
) )
grad = state["projector"].project(grad, state["step"]) grad = state["projector"].project(grad, state["step"])