stuff for interfacing with the loss scaler value (because I want to cap it)
commit 2dd80a03ff (parent a30dffcca7)
@@ -569,6 +569,7 @@ class DeepSpeed:
     loss_scale_window: int = 1000
     min_loss_scale: float = 32768.0
+    max_loss_scale: float = 1048576.0
     loss_scale = 0.0
 
     config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
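Note: loss_scale_window and min_loss_scale correspond to DeepSpeed's dynamic loss-scaling options, while max_loss_scale is introduced by this commit and is enforced by the training loop itself rather than handed to DeepSpeed. A commented restatement of the fields (illustrative only, not the actual config class):

from dataclasses import dataclass, field

@dataclass
class DeepSpeed:
    loss_scale_window: int = 1000      # consecutive overflow-free steps before DeepSpeed raises the dynamic scale
    min_loss_scale: float = 32768.0    # floor for the dynamic scale
    max_loss_scale: float = 1048576.0  # new in this commit: ceiling enforced by the training loop; 0 disables the cap
    loss_scale = 0.0                   # 0.0 leaves the scale fully dynamic
    config: dict = field(default_factory=dict)  # passed through to the raw DeepSpeed config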
@@ -282,6 +282,20 @@ class Engine():
             elif 'lr' in param_group:
                 param_group['lr'] = lr
 
+    def get_loss_scale(self):
+        if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+            return 1
+
+        return self.loss_scaler.get_scale()
+
+    def set_loss_scale(self, value):
+        if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+            return
+
+        """
+        self.optimizer.loss_scale = value
+        """
+
     def get_global_grad_norm(self):
         return self._global_grad_norm
 
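The stock Engine above only reads the scale; its setter body stays commented out. For comparison only, and assuming loss_scaler is a torch.cuda.amp.GradScaler (an assumption, not something this diff states), the scale could actually be read and forced like so:

import torch

if torch.cuda.is_available():
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0**20)
    param = torch.zeros(1, device="cuda", requires_grad=True)
    opt = torch.optim.SGD([param], lr=1.0)

    scaler.scale(param.sum()).backward()   # initialises the internal scale tensor
    scaler.step(opt)

    print(scaler.get_scale())              # the call the new get_loss_scale() relies on
    scaler.update(new_scale=65536.0)       # one way a working set_loss_scale() could cap the scale
    print(scaler.get_scale())              # 65536.0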
@@ -457,6 +471,12 @@ class Engines(dict[str, Engine]):
                 continue
             engine.set_lr(lr)
 
+    def set_loss_scale(self, lr):
+        for engine in self.values():
+            if not engine._training:
+                continue
+            engine.set_loss_scale(lr)
+
     def _update(self):
         for engine in self.values():
             self._global_step = max(self._global_step, engine.global_step)
@@ -584,11 +604,11 @@ class Engines(dict[str, Engine]):
             elapsed_time = time.time() - start_time
             total_elapsed_time += elapsed_time
             grad_norm = engine.get_global_grad_norm()
-            loss_scale = 1
-            if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
-                loss_scale = engine.optimizer.loss_scale
-            elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
-                loss_scale = engine.loss_scaler.get_scale()
+            loss_scale = engine.get_loss_scale()
+
+            if cfg.trainer.deepspeed.max_loss_scale > 0 and loss_scale > cfg.trainer.deepspeed.max_loss_scale:
+                _logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
+                engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
             if grad_norm is not None:
                 grad_norm /= loss_scale
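Why the gradient norm is divided by the loss scale: under fp16 dynamic loss scaling the backward pass runs on the scaled loss, so the raw gradients, and therefore their norm, are inflated by exactly that factor; dividing recovers the norm of the true gradients (assuming the norm was taken before unscaling, which is what the division above implies). A tiny worked example:

import torch

g_true = torch.tensor([0.3, -0.4])       # gradients of the unscaled loss, true norm 0.5
loss_scale = 65536.0                      # a typical fp16 dynamic scale

g_scaled = g_true * loss_scale            # what backward() produces when the loss is scaled
raw_norm = g_scaled.norm().item()         # inflated by the scale: 32768.0
print(raw_norm / loss_scale)              # 0.5, matching `grad_norm /= loss_scale`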
@@ -119,6 +119,19 @@ class Engine(DeepSpeedEngine):
         except Exception as e:
             _logger.warning(str(e))
 
+    # cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR
+    def get_loss_scale(self):
+        if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+            return 1.0
+
+        return self.optimizer.cur_scale
+
+    def set_loss_scale(self, value):
+        if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+            return
+
+        self.optimizer.cur_scale = value
+
     # we'll just have to live with the LoRA weights living within our main weights
     # they're easy to extract anyways
     def load_checkpoint(self, load_dir, **kwargs ):
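A self-contained sketch of the pattern the DeepSpeed-side accessors above rely on: all that is required of the wrapped optimizer is a cur_scale attribute (which DeepSpeed's fp16 optimizer wrappers expose), and anything else degrades to a no-op. FakeFP16Optimizer and cap_loss_scale below are purely illustrative names:

class FakeFP16Optimizer:
    """Stand-in for a DeepSpeed fp16 optimizer wrapper; only cur_scale matters here."""
    def __init__(self, cur_scale=2.0**21):
        self.cur_scale = cur_scale

def get_loss_scale(optimizer):
    # mirrors Engine.get_loss_scale: fall back to 1.0 when there is nothing to read
    if not hasattr(optimizer, "cur_scale") or optimizer.cur_scale is None:
        return 1.0
    return optimizer.cur_scale

def cap_loss_scale(optimizer, max_loss_scale):
    # mirrors the Engines loop: only clamp when a positive cap is configured
    if max_loss_scale > 0 and get_loss_scale(optimizer) > max_loss_scale:
        optimizer.cur_scale = max_loss_scale

opt = FakeFP16Optimizer()
cap_loss_scale(opt, max_loss_scale=1048576.0)
print(get_loss_scale(opt))  # 1048576.0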
@@ -233,7 +233,16 @@ def train(
                 engines.set_lr(rate)
                 _logger.info(f"Updating LR to: {rate}")
             except Exception as e:
-                _logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
+                _logger.warning(f"Failed to set LR to: {rate}, {str(e)}")
+
+        if "loss_scale" in command:
+            value = float(command.split(" ")[-1])
+            try:
+                engines.set_loss_scale(value)
+                _logger.info(f"Updating loss scale to: {value}")
+            except Exception as e:
+                raise e
+                _logger.warning(f"Failed to set loss scale to: {value}, {str(e)}")
 
         if "export" in command:
             train_dl.dataset.save_state_dict()
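For reference, the new handler only looks for the substring "loss_scale" and parses the last space-separated token as a float, so a command such as "loss_scale 65536" maps to engines.set_loss_scale(65536.0). A minimal sketch of just the parsing (the command plumbing itself is assumed from the surrounding code):

command = "loss_scale 65536"               # example command string

if "loss_scale" in command:
    value = float(command.split(" ")[-1])  # -> 65536.0
    print(f"Updating loss scale to: {value}")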