stuff for interfacing with the loss scaler value (because I want to cap it)

mrq 2025-03-06 17:07:29 -06:00
parent a30dffcca7
commit 2dd80a03ff
4 changed files with 49 additions and 6 deletions

View File

@@ -569,6 +569,7 @@ class DeepSpeed:
loss_scale_window: int = 1000
min_loss_scale: float = 32768.0
max_loss_scale: float = 1048576.0
loss_scale: float = 0.0 # 0.0 lets DeepSpeed pick the scale dynamically; a nonzero value requests a static scale
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
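For context, a rough sketch of where these fields would land in the generated DeepSpeed config, assuming the standard fp16 keys (the actual wiring lives elsewhere in the config code; this mapping is an illustration, not something the commit adds). Note that max_loss_scale is not a DeepSpeed key at all: it is this repo's own ceiling, consumed by the training loop further down.

# hypothetical ds_config mapping; key names are DeepSpeed's standard fp16 options
ds_cfg = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0.0,          # 0 selects dynamic loss scaling
        "loss_scale_window": 1000,
        "min_loss_scale": 32768.0,
    },
}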

View File

@@ -282,6 +282,20 @@ class Engine():
elif 'lr' in param_group:
param_group['lr'] = lr
def get_loss_scale(self):
if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
return 1.0
return self.loss_scaler.get_scale()
def set_loss_scale(self, value):
if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
return
# deliberately a no-op for now; the commented line is what a direct override would look like
# self.optimizer.loss_scale = value
def get_global_grad_norm(self):
return self._global_grad_norm
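The local (non-DeepSpeed) setter above is deliberately a stub. If the backing loss_scaler happens to be a torch.cuda.amp.GradScaler, one way to actually apply a new scale is GradScaler.update(), which accepts a replacement value. A minimal sketch under that assumption (not what this commit ships):

# minimal sketch, assuming self.loss_scaler is a torch.cuda.amp.GradScaler
def set_loss_scale(self, value):
    if getattr(self, "loss_scaler", None) is None:
        return
    # GradScaler has no direct setter, but update(new_scale=...) replaces the scale
    self.loss_scaler.update(new_scale=value)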
@@ -457,6 +471,12 @@ class Engines(dict[str, Engine]):
continue
engine.set_lr(lr)
def set_loss_scale(self, value):
for engine in self.values():
if not engine._training:
continue
engine.set_loss_scale(value)
def _update(self):
for engine in self.values():
self._global_step = max(self._global_step, engine.global_step)
@@ -584,11 +604,11 @@
elapsed_time = time.time() - start_time
total_elapsed_time += elapsed_time
grad_norm = engine.get_global_grad_norm()
-loss_scale = 1
-if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
-loss_scale = engine.optimizer.loss_scale
-elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
-loss_scale = engine.loss_scaler.get_scale()
+loss_scale = engine.get_loss_scale()
if cfg.trainer.deepspeed.max_loss_scale > 0 and loss_scale > cfg.trainer.deepspeed.max_loss_scale:
_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
if grad_norm is not None:
grad_norm /= loss_scale
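Condensed, the per-step flow above is: read the scale through the new accessor, cap it against the configured ceiling, then unscale the reported grad norm. A paraphrase of that logic (same names as the diff, no new behavior):

# paraphrase of the hunk above
loss_scale = engine.get_loss_scale()
max_scale = cfg.trainer.deepspeed.max_loss_scale
if max_scale > 0 and loss_scale > max_scale:
    engine.set_loss_scale(max_scale)  # takes effect on subsequent steps
# grad_norm is divided by the pre-cap scale, which is right here: this
# step's gradients were produced under that scale
if grad_norm is not None:
    grad_norm /= loss_scale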

View File

@@ -119,6 +119,19 @@ class Engine(DeepSpeedEngine):
except Exception as e:
_logger.warning(str(e))
# cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR
def get_loss_scale(self):
if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
return 1.0
return self.optimizer.cur_scale
def set_loss_scale(self, value):
if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
return
self.optimizer.cur_scale = value
# we'll just have to live with the LoRA weights living within our main weights
# they're easy to extract anyways
def load_checkpoint(self, load_dir, **kwargs ):
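For reference, DeepSpeed's fp16 optimizers keep the dynamic loss scale in optimizer.cur_scale, which is why the accessors above poke that attribute directly instead of going through the engine's loss_scale property. A hypothetical capping call through this interface:

# hypothetical usage; `engine` is an instance of the subclass above
scale = engine.get_loss_scale()     # optimizer.cur_scale, or 1.0 if unavailable
if scale > 1048576.0:               # e.g. the config's max_loss_scale default
    engine.set_loss_scale(1048576.0)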

View File

@@ -233,7 +233,16 @@ def train(
engines.set_lr(rate)
_logger.info(f"Updating LR to: {rate}")
except Exception as e:
_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
_logger.warning(f"Failed to set LR to: {rate}, {str(e)}")
if "loss_scale" in command:
value = float(command.split(" ")[-1])
try:
engines.set_loss_scale(value)
_logger.info(f"Updating loss scale to: {value}")
except Exception as e:
_logger.warning(f"Failed to set loss scale to: {value}, {str(e)}")
if "export" in command:
train_dl.dataset.save_state_dict()
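With the handler in place, the loss scale can be adjusted at runtime over the same command channel as the lr handler above; the last whitespace-separated token of the command is parsed as a float. A hypothetical session input:

loss_scale 65536

Like the other handlers, matching is by substring ("loss_scale" in command), so any command containing that token will trigger it.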