NaN loss detection (should have added it earlier), loss scaling for local backend + fp16
parent a755eb3c62
commit 856545f8bb
@@ -544,6 +544,13 @@ class Trainer:
 			return torch.float8_e4m3fn
 		return torch.float32
 
+	@cached_property
+	def scale_loss(self):
+		# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
+		if self.backend != "local":
+			return False
+		return self.dtype == torch.float16
+
 
 @dataclass()
 class Inference:
@@ -67,6 +67,9 @@ class Engine():
 
 		self._frozen_params = set()
 
+		self.max_nan_losses = 8
+		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
+
 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@@ -186,8 +189,11 @@ class Engine():
 		return self.module.forward(*args, **kwargs)
 
 	def backward(self, loss):
+		if self.loss_scaler is not None:
+			return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
 		return (loss / self.gradient_accumulation_steps).backward()
 
+
 	def step(self):
 		with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
 			self.micro_steps += 1
@@ -197,7 +203,11 @@ class Engine():
 					torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
 
 				self.global_steps += 1
-				self.optimizer.step()
+				if self.loss_scaler is not None:
+					self.loss_scaler.step(self.optimizer)
+					self.loss_scaler.update()
+				else:
+					self.optimizer.step()
 				self.optimizer.zero_grad()
 
 				self._get_grad_norm()
@@ -232,6 +242,11 @@ class Engine():
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
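For context, the backward()/step() changes above follow the standard torch.cuda.amp.GradScaler recipe for fp16 training: scale the loss before backprop, let the scaler unscale and validate gradients before the optimizer step, then update the scale factor. A minimal standalone sketch of that recipe, where model, optimizer, and batch are placeholders rather than code from this repo and a CUDA device is assumed:

import torch

# Minimal sketch of the GradScaler pattern the backward()/step() changes follow.
# `model`, `optimizer`, and `batch` are illustrative names, not from this repo.
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # only created when loss scaling is enabled

for _ in range(4):
	batch = torch.randn(8, 16, device="cuda")
	with torch.autocast(device_type="cuda", dtype=torch.float16):
		loss = model(batch).pow(2).mean()

	# backward(): scale the loss before backprop so small fp16 gradients don't underflow
	scaler.scale(loss).backward()

	# step(): unscale gradients and skip the optimizer step if any are inf/NaN;
	# update() then adjusts the scale factor for the next iteration
	scaler.step(optimizer)
	scaler.update()
	optimizer.zero_grad()

This is also why the scaled path replaces the plain optimizer.step() call: the scaler must own the step so it can drop updates produced from overflowed gradients.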
@@ -60,6 +60,8 @@ class Engine(DeepSpeedEngine):
 		self.global_samples = stats["global_samples"]
 		self.tokens_processed = stats["tokens_processed"]
 
+		self.max_nan_losses = 8
+
 	def freeze(self, freeze_all=True):
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
 			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
@@ -113,6 +115,11 @@ class Engine(DeepSpeedEngine):
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
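Both backends now share the same NaN-loss guard: a budget of 8 NaN batch losses, decremented on each occurrence, with a RuntimeError raised once the budget goes negative (that is, on the ninth NaN). A standalone sketch of that logic, using a hypothetical helper class rather than anything from this commit:

import torch

# Hypothetical helper mirroring the guard added to both Engine classes above:
# start with a budget of `max_nan_losses`, decrement whenever the summed loss
# is NaN, and abort training once the budget drops below zero.
class NaNLossGuard:
	def __init__(self, max_nan_losses: int = 8):
		self.max_nan_losses = max_nan_losses

	def check(self, loss: torch.Tensor):
		if torch.isnan(loss).any():
			self.max_nan_losses -= 1
			if self.max_nan_losses < 0:
				raise RuntimeError("Too many NaN losses detected.")

guard = NaNLossGuard()
guard.check(torch.tensor(float("nan")))  # uses up one of the eight allowed NaN losses

Note that the counter is never reset after a healthy batch, so the budget is cumulative over the whole run rather than a limit on consecutive NaN losses.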