From cbf6b84e2774e473d52cadd7df22b9af1c52391d Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 23 Feb 2025 19:08:26 -0600
Subject: [PATCH] fixed grad norm and loss scale not reporting for local trainer

---
 vall_e/engines/base.py   | 7 +++++--
 vall_e/utils/ext/muon.py | 4 ++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 9f134c0..f8ece81 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -248,7 +248,7 @@ class Engine():
 		self.global_samples += self.batch_size
 
 		if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
-			torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
+			self._global_grad_norm = torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
 			self.global_steps += 1
 
 			if self.loss_scaler is not None:
@@ -260,6 +260,7 @@ class Engine():
 
 		self._get_grad_norm()
 
+	# doesn't actually work
 	def _get_grad_norm(self):
 		t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
 		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else None
@@ -585,7 +586,9 @@ class Engines(dict[str, Engine]):
 				loss_scale = 1
 				if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
 					loss_scale = engine.optimizer.loss_scale
-
+				elif engine.loss_scaler is not None:
+					loss_scale = engine.loss_scaler.get_scale()
+
 				if grad_norm is not None:
 					grad_norm /= loss_scale
 
diff --git a/vall_e/utils/ext/muon.py b/vall_e/utils/ext/muon.py
index cee740b..0bb4a9f 100644
--- a/vall_e/utils/ext/muon.py
+++ b/vall_e/utils/ext/muon.py
@@ -126,6 +126,10 @@ class Muon(torch.optim.Optimizer):
 			#           Muon           #
 			############################
 
+			# this actually doesn't work with deepspeed for the same reason APOLLO required modifications:
+			# deepspeed's BF16/F16 optimizer wrapper modifies the tensors, so self.state loses the right mapping
+			# can't be assed to figure it out right now since it's not easy to fix like APOLLO
+
 			params = [p for p in group["params"] if self.state[p]["use_muon"]]
 			# import pdb; pdb.set_trace()
 			lr = group["lr"]
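
Note (not part of the patch): a minimal sketch of why the two changes above are enough to
report a grad norm and loss scale with the local (non-DeepSpeed) trainer. The model and
scaler names below are hypothetical stand-ins for engine.module and engine.loss_scaler;
torch.nn.utils.clip_grad_norm_ returns the total gradient norm, and GradScaler.get_scale()
exposes the current loss scale that the new elif branch reads so the reported norm can be
un-scaled.

    import torch

    model = torch.nn.Linear(4, 4)                      # stand-in for engine.module
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # stand-in for engine.loss_scaler

    loss = model(torch.randn(2, 4)).sum()
    scaler.scale(loss).backward()

    # clip_grad_norm_ returns the total norm of all parameter gradients, so capturing
    # its return value (as the patch does) populates the reported grad norm for free
    global_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # the new elif branch divides the (scaled) grad norm by the scaler's current
    # loss scale before reporting; get_scale() returns 1.0 when scaling is disabled
    loss_scale = scaler.get_scale()
    print(global_grad_norm.item() / loss_scale)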