From 2dd80a03ff77464b3d11de3d30ef07de7cbdf1b6 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 6 Mar 2025 17:07:29 -0600
Subject: [PATCH] stuff for interfacing with the loss scaler value (because I
 want to cap it)

---
 vall_e/config.py            |  1 +
 vall_e/engines/base.py      | 30 +++++++++++++++++++++++++-----
 vall_e/engines/deepspeed.py | 13 +++++++++++++
 vall_e/utils/trainer.py     | 10 +++++++++-
 4 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index c64880a..e4c29cb 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -569,6 +569,7 @@ class DeepSpeed:
 
 	loss_scale_window: int = 1000
 	min_loss_scale: float = 32768.0
+	max_loss_scale: float = 1048576.0
 	loss_scale = 0.0
 
 	config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 7d63f56..7d9388e 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -282,6 +282,20 @@ class Engine():
 			elif 'lr' in param_group:
 				param_group['lr'] = lr
 
+	def get_loss_scale(self):
+		if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+			return 1
+
+		return self.loss_scaler.get_scale()
+
+	def set_loss_scale(self, value):
+		if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+			return
+
+		"""
+		self.optimizer.loss_scale = value
+		"""
+
 	def get_global_grad_norm(self):
 		return self._global_grad_norm
 
@@ -457,6 +471,12 @@ class Engines(dict[str, Engine]):
 				continue
 			engine.set_lr(lr)
 
+	def set_loss_scale(self, value):
+		for engine in self.values():
+			if not engine._training:
+				continue
+			engine.set_loss_scale(value)
+
 	def _update(self):
 		for engine in self.values():
 			self._global_step = max(self._global_step, engine.global_step)
@@ -584,11 +604,11 @@ class Engines(dict[str, Engine]):
 			elapsed_time = time.time() - start_time
 			total_elapsed_time += elapsed_time
 			grad_norm = engine.get_global_grad_norm()
-			loss_scale = 1
-			if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
-				loss_scale = engine.optimizer.loss_scale
-			elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
-				loss_scale = engine.loss_scaler.get_scale()
+			loss_scale = engine.get_loss_scale()
+
+			if cfg.trainer.deepspeed.max_loss_scale > 0 and loss_scale > cfg.trainer.deepspeed.max_loss_scale:
+				_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
+				engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
 			if grad_norm is not None:
 				grad_norm /= loss_scale
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 85ec63b..27e4276 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -119,6 +119,19 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			_logger.warning(str(e))
 
+	# cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR
+	def get_loss_scale(self):
+		if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+			return 1.0
+
+		return self.optimizer.cur_scale
+
+	def set_loss_scale(self, value):
+		if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+			return
+
+		self.optimizer.cur_scale = value
+
 	# we'll just have to live with the LoRA weights living within our main weights
 	# they're easy to extract anyways
 	def load_checkpoint(self, load_dir, **kwargs ):
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 09609e7..a563e29 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -233,7 +233,15 @@ def train(
 				engines.set_lr(rate)
 				_logger.info(f"Updating LR to: {rate}")
 			except Exception as e:
-				_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
+				_logger.warning(f"Failed to set LR to: {rate}, {str(e)}")
+
+		if "loss_scale" in command:
+			value = float(command.split(" ")[-1])
+			try:
+				engines.set_loss_scale(value)
+				_logger.info(f"Updating loss scale to: {value}")
+			except Exception as e:
+				_logger.warning(f"Failed to set loss scale to: {value}, {str(e)}")
 
 		if "export" in command:
 			train_dl.dataset.save_state_dict()
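
The hunks above boil down to: read the engine's current loss scale, clamp it to cfg.trainer.deepspeed.max_loss_scale when it overshoots, and divide the reported grad norm by the (possibly capped) scale. A minimal standalone sketch of that logic follows, assuming a stand-in FakeEngine with the same get_loss_scale()/set_loss_scale() pair the patch adds; the class, the helper name, and the numbers are made up for illustration and are not part of vall_e.

	# sketch of the capping check the patch wires into Engines' step loop, against a fake engine
	class FakeEngine:
		"""Stand-in exposing the get/set interface the patch adds (not vall_e's Engine)."""
		def __init__(self, scale):
			self._scale = scale

		def get_loss_scale(self):
			return self._scale

		def set_loss_scale(self, value):
			self._scale = value

	def cap_loss_scale(engine, max_loss_scale):
		"""Mirror of the check above: cap the scaler if it grew past the configured max."""
		loss_scale = engine.get_loss_scale()
		if max_loss_scale > 0 and loss_scale > max_loss_scale:
			print(f"Loss scale ({loss_scale}) exceeds max_loss_scale ({max_loss_scale}), capping...")
			engine.set_loss_scale(max_loss_scale)
		return engine.get_loss_scale()

	if __name__ == "__main__":
		engine = FakeEngine(scale=4194304.0)        # 2**22, above the default cap
		capped = cap_loss_scale(engine, 1048576.0)  # 1048576.0 = max_loss_scale default from config.py
		grad_norm = 2097152.0                       # made-up scaled grad norm
		print(grad_norm / capped)                   # 2.0 -- un-scaled norm, as done after the cap

At runtime the same adjustment can also be forced by hand through the trainer command channel: the new handler parses the last token of a "loss_scale <number>" command as a float and passes it to engines.set_loss_scale, mirroring the existing "lr <number>" command.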