ugh
This commit is contained in:
parent
3dd31e74d1
commit
2ba6b483dc
|
@ -300,9 +300,11 @@ class Apollo(Optimizer):
|
|||
"proj": proj,
|
||||
"scale_type": scale_type,
|
||||
"scale": scale,
|
||||
"update_proj_gap": update_proj_gap,
|
||||
"proj_type": proj_type,
|
||||
}
|
||||
|
||||
self.update_proj_gap = update_proj_gap
|
||||
self.proj_type = proj_type
|
||||
|
||||
super().__init__(params, defaults)
|
||||
|
||||
# do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe
|
||||
|
@ -343,18 +345,18 @@ class Apollo(Optimizer):
|
|||
if group["proj"] == "random":
|
||||
state["projector"] = GradientProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
update_proj_gap=self.update_proj_gap,
|
||||
proj_type=self.proj_type,
|
||||
alpha=group["scale"],
|
||||
proj_type=group["proj_type"],
|
||||
seed=state["seed"]
|
||||
)
|
||||
|
||||
elif group["proj"] == "svd":
|
||||
state["projector"] = GaLoreProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
update_proj_gap=self.update_proj_gap,
|
||||
proj_type=self.proj_type,
|
||||
scale=group["scale"],
|
||||
proj_type=group["proj_type"]
|
||||
)
|
||||
|
||||
grad = state["projector"].project(grad, state["step"])
|
||||
|
|
Loading…
Reference in New Issue
Block a user