From 2ba6b483dcbe69fe5e55b12e3fb2b8f73de43c96 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 14 Dec 2024 22:43:51 -0600 Subject: [PATCH] ugh --- vall_e/utils/ext/apollo.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vall_e/utils/ext/apollo.py b/vall_e/utils/ext/apollo.py index 9dfe8b5..1a0e32e 100644 --- a/vall_e/utils/ext/apollo.py +++ b/vall_e/utils/ext/apollo.py @@ -300,9 +300,11 @@ class Apollo(Optimizer): "proj": proj, "scale_type": scale_type, "scale": scale, - "update_proj_gap": update_proj_gap, - "proj_type": proj_type, } + + self.update_proj_gap = update_proj_gap + self.proj_type = proj_type + super().__init__(params, defaults) # do NOT do anything to the params afterwards or it'll cause some desync in deepspeed's optimizer wrapper or something cringe @@ -343,18 +345,18 @@ class Apollo(Optimizer): if group["proj"] == "random": state["projector"] = GradientProjector( group["rank"], - update_proj_gap=group["update_proj_gap"], + update_proj_gap=self.update_proj_gap, + proj_type=self.proj_type, alpha=group["scale"], - proj_type=group["proj_type"], seed=state["seed"] ) elif group["proj"] == "svd": state["projector"] = GaLoreProjector( group["rank"], - update_proj_gap=group["update_proj_gap"], + update_proj_gap=self.update_proj_gap, + proj_type=self.proj_type, scale=group["scale"], - proj_type=group["proj_type"] ) grad = state["projector"].project(grad, state["step"])