stuff for interfacing with the loss scaler value (because I want to cap it)
parent a30dffcca7
commit 2dd80a03ff
@@ -569,6 +569,7 @@ class DeepSpeed:
 	loss_scale_window: int = 1000
 	min_loss_scale: float = 32768.0
+	max_loss_scale: float = 1048576.0
 	loss_scale = 0.0
 
 	config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
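For reference, a minimal stand-in for the config fields touched by this hunk (the real dataclass has more fields than the context shows). Per the check added in the Engines hunk further down, a max_loss_scale of 0 leaves the scale uncapped:

from dataclasses import dataclass, field

# Stand-in excerpt of the DeepSpeed config dataclass, limited to the fields visible in
# this hunk; the new max_loss_scale default of 1048576 (2**20) is the ceiling the
# trainer clamps to, and 0 disables the cap.
@dataclass
class DeepSpeed:
	loss_scale_window: int = 1000
	min_loss_scale: float = 32768.0
	max_loss_scale: float = 1048576.0
	loss_scale: float = 0.0
	config: dict = field(default_factory=lambda: {})  # passed through to the deepspeed config

assert DeepSpeed().max_loss_scale == 2 ** 20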
@@ -282,6 +282,20 @@ class Engine():
 		elif 'lr' in param_group:
 			param_group['lr'] = lr
 
+	def get_loss_scale(self):
+		if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+			return 1
+
+		return self.loss_scaler.get_scale()
+
+	def set_loss_scale(self, value):
+		if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+			return
+
+		"""
+		self.optimizer.loss_scale = value
+		"""
+
 	def get_global_grad_norm(self):
 		return self._global_grad_norm
 
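Two notes on this hunk: the base-class setter's body is just a docstring, so it is effectively a no-op (only the DeepSpeed engine subclass further down actually writes a new scale), and the getter's get_scale() call matches torch's GradScaler API. A hedged sketch, assuming self.loss_scaler is a torch.cuda.amp.GradScaler (its construction is not shown in this diff):

import torch

# Assuming loss_scaler is a torch.cuda.amp.GradScaler, get_loss_scale() above reduces
# to GradScaler.get_scale(). Without CUDA the scaler disables itself and get_scale()
# reports 1.0, which lines up with the method's fallback return of 1.
scaler = torch.cuda.amp.GradScaler(init_scale=65536.0, enabled=torch.cuda.is_available())
print(scaler.get_scale())  # 65536.0 when enabled, 1.0 when disabled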
@@ -457,6 +471,12 @@ class Engines(dict[str, Engine]):
 				continue
 			engine.set_lr(lr)
 
+	def set_loss_scale(self, lr):
+		for engine in self.values():
+			if not engine._training:
+				continue
+			engine.set_loss_scale(lr)
+
 	def _update(self):
 		for engine in self.values():
 			self._global_step = max(self._global_step, engine.global_step)
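A toy illustration of the broadcast above, using hypothetical stand-in engines in place of the repo's Engine instances; only members flagged as training receive the new scale:

# Hypothetical stand-ins; the real objects are the repo's Engine instances.
class StubEngine:
	def __init__(self, training: bool):
		self._training = training
		self.scale = 65536.0
	def set_loss_scale(self, value):
		self.scale = value

engines = {"model_a": StubEngine(True), "model_b": StubEngine(False)}
for engine in engines.values():
	if not engine._training:
		continue
	engine.set_loss_scale(131072.0)

assert engines["model_a"].scale == 131072.0
assert engines["model_b"].scale == 65536.0  # not training, left untouched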
@@ -584,11 +604,11 @@ class Engines(dict[str, Engine]):
 			elapsed_time = time.time() - start_time
 			total_elapsed_time += elapsed_time
 			grad_norm = engine.get_global_grad_norm()
-			loss_scale = 1
-			if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
-				loss_scale = engine.optimizer.loss_scale
-			elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
-				loss_scale = engine.loss_scaler.get_scale()
+			loss_scale = engine.get_loss_scale()
+
+			if cfg.trainer.deepspeed.max_loss_scale > 0 and loss_scale > cfg.trainer.deepspeed.max_loss_scale:
+				_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
+				engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
 			if grad_norm is not None:
 				grad_norm /= loss_scale
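The capping rule itself is small enough to state on its own. A minimal sketch; the helper name is hypothetical, since the repo applies the check inline as above:

# Hypothetical free-standing version of the inline check above: a positive
# max_loss_scale acts as a ceiling, zero (or negative) leaves the scale alone.
def cap_loss_scale(current: float, max_loss_scale: float) -> float:
	if max_loss_scale > 0 and current > max_loss_scale:
		return max_loss_scale
	return current

assert cap_loss_scale(2097152.0, 1048576.0) == 1048576.0  # clamped to the ceiling
assert cap_loss_scale(65536.0, 1048576.0) == 65536.0      # below the ceiling, unchanged
assert cap_loss_scale(2097152.0, 0.0) == 2097152.0        # 0 disables the cap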
@@ -119,6 +119,19 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			_logger.warning(str(e))
 
+	# cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR
+	def get_loss_scale(self):
+		if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+			return 1.0
+
+		return self.optimizer.cur_scale
+
+	def set_loss_scale(self, value):
+		if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+			return
+
+		self.optimizer.cur_scale = value
+
 	# we'll just have to live with the LoRA weights living within our main weights
 	# they're easy to extract anyways
 	def load_checkpoint(self, load_dir, **kwargs ):
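These accessors poke at optimizer.cur_scale, the attribute DeepSpeed's fp16 optimizer wrapper uses to track the current dynamic loss scale. A stand-in sketch; DummyFP16Optimizer is hypothetical and only mimics that one attribute:

# Hypothetical stand-in for DeepSpeed's fp16 optimizer wrapper, which keeps the
# current dynamic loss scale in `cur_scale`; the engine accessors above just read
# and overwrite that attribute, falling back to 1.0 when it is absent.
class DummyFP16Optimizer:
	cur_scale: float = 65536.0

opt = DummyFP16Optimizer()
scale = getattr(opt, "cur_scale", 1.0)    # what get_loss_scale() returns
opt.cur_scale = min(scale, 1048576.0)     # what set_loss_scale() ends up doing when capping
assert opt.cur_scale == 65536.0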
@@ -233,7 +233,16 @@ def train(
 					engines.set_lr(rate)
 					_logger.info(f"Updating LR to: {rate}")
 				except Exception as e:
-					_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
+					_logger.warning(f"Failed to set LR to: {rate}, {str(e)}")
+
+			if "loss_scale" in command:
+				value = float(command.split(" ")[-1])
+				try:
+					engines.set_loss_scale(value)
+					_logger.info(f"Updating loss scale to: {value}")
+				except Exception as e:
+					raise e
+					_logger.warning(f"Failed to set loss scale to: {value}, {str(e)}")
 
 			if "export" in command:
 				train_dl.dataset.save_state_dict()
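The new branch mirrors the lr command's convention of taking the last whitespace-separated token as the value; note that, as committed, the except branch re-raises before its warning, so that warning line never executes. A quick check of the parsing, with a hypothetical command string (the training loop's actual command source is not shown in this diff):

# Hypothetical command string, parsed the same way the hunk above parses it.
command = "loss_scale 262144"

if "loss_scale" in command:
	value = float(command.split(" ")[-1])
	assert value == 262144.0  # this is what gets handed to engines.set_loss_scale()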