ugh
This commit is contained in:
parent
3dd31e74d1
commit
2ba6b483dc
|
@ -300,9 +300,11 @@ class Apollo(Optimizer):
|
||||||
"proj": proj,
|
"proj": proj,
|
||||||
"scale_type": scale_type,
|
"scale_type": scale_type,
|
||||||
"scale": scale,
|
"scale": scale,
|
||||||
"update_proj_gap": update_proj_gap,
|
|
||||||
"proj_type": proj_type,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.update_proj_gap = update_proj_gap
|
||||||
|
self.proj_type = proj_type
|
||||||
|
|
||||||
super().__init__(params, defaults)
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
# do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
|
# do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
|
||||||
|
@ -343,18 +345,18 @@ class Apollo(Optimizer):
|
||||||
if group["proj"] == "random":
|
if group["proj"] == "random":
|
||||||
state["projector"] = GradientProjector(
|
state["projector"] = GradientProjector(
|
||||||
group["rank"],
|
group["rank"],
|
||||||
update_proj_gap=group["update_proj_gap"],
|
update_proj_gap=self.update_proj_gap,
|
||||||
|
proj_type=self.proj_type,
|
||||||
alpha=group["scale"],
|
alpha=group["scale"],
|
||||||
proj_type=group["proj_type"],
|
|
||||||
seed=state["seed"]
|
seed=state["seed"]
|
||||||
)
|
)
|
||||||
|
|
||||||
elif group["proj"] == "svd":
|
elif group["proj"] == "svd":
|
||||||
state["projector"] = GaLoreProjector(
|
state["projector"] = GaLoreProjector(
|
||||||
group["rank"],
|
group["rank"],
|
||||||
update_proj_gap=group["update_proj_gap"],
|
update_proj_gap=self.update_proj_gap,
|
||||||
|
proj_type=self.proj_type,
|
||||||
scale=group["scale"],
|
scale=group["scale"],
|
||||||
proj_type=group["proj_type"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
grad = state["projector"].project(grad, state["step"])
|
grad = state["projector"].project(grad, state["step"])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user