nan loss detection (should have added it earlier), loss scaling for local backend + fp16

mrq 2024-05-11 22:23:29 -05:00
parent a755eb3c62
commit 856545f8bb
3 changed files with 30 additions and 1 deletion

View File

@@ -544,6 +544,13 @@ class Trainer:
 			return torch.float8_e4m3fn
 		return torch.float32
 
+	@cached_property
+	def scale_loss(self):
+		# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
+		if self.backend != "local":
+			return False
+		return self.dtype == torch.float16
+
 @dataclass()
 class Inference:
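
For context on what this property gates: loss scaling only matters when the local backend trains in pure float16, since bfloat16 and float32 keep float32's exponent range, and DeepSpeed performs its own dynamic scaling. A minimal standalone sketch of the gate (the should_scale_loss helper is hypothetical; the backend and dtype values mirror the fields used above):

import torch

def should_scale_loss(backend: str, dtype: torch.dtype) -> bool:
	# DeepSpeed applies its own dynamic loss scaling, so only the
	# local backend needs an explicit torch.cuda.amp.GradScaler
	if backend != "local":
		return False
	# bf16/fp32 have enough exponent range; only fp16 underflows
	return dtype == torch.float16

assert should_scale_loss("local", torch.float16)
assert not should_scale_loss("local", torch.bfloat16)
assert not should_scale_loss("deepspeed", torch.float16)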

View File

@@ -67,6 +67,9 @@ class Engine():
 		self._frozen_params = set()
 
+		self.max_nan_losses = 8
+		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
+
 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@@ -186,8 +189,11 @@ class Engine():
 		return self.module.forward(*args, **kwargs)
 
 	def backward(self, loss):
+		if self.loss_scaler is not None:
+			return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
+
 		return (loss / self.gradient_accumulation_steps).backward()
 
 	def step(self):
 		with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
 			self.micro_steps += 1
@@ -197,7 +203,11 @@
 					torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
 
 				self.global_steps += 1
-				self.optimizer.step()
+				if self.loss_scaler is not None:
+					self.loss_scaler.step(self.optimizer)
+					self.loss_scaler.update()
+				else:
+					self.optimizer.step()
 				self.optimizer.zero_grad()
 
 				self._get_grad_norm()
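
For comparison, the reference pattern in the PyTorch AMP docs calls scaler.unscale_(optimizer) before clipping, so clip_grad_norm_ sees true gradient magnitudes rather than scaled ones; in the hunk above, clipping runs on still-scaled gradients, so the threshold effectively applies to scaled values. A minimal sketch of that reference loop (the model and optimizer are placeholders):

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
	x = torch.randn(8, 10, device="cuda")
	with torch.autocast(device_type="cuda", dtype=torch.float16):
		loss = model(x).square().mean()
	scaler.scale(loss).backward()
	# unscale first so the clip threshold applies to real magnitudes
	scaler.unscale_(optimizer)
	torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
	scaler.step(optimizer)   # silently skips the step on inf/NaN grads
	scaler.update()          # grows/shrinks the scale factor over time
	optimizer.zero_grad()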
@@ -232,6 +242,11 @@ class Engine():
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
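
A note on why checking only the summed loss is enough: losses is a dict of per-objective tensors, and a single NaN term poisons the stacked sum, so torch.isnan(loss).any() trips if any component went NaN. A tiny sketch (the "ar"/"nar" keys are hypothetical):

import torch

losses = {"ar": torch.tensor(2.3), "nar": torch.tensor(float("nan"))}
loss = torch.stack([*losses.values()]).sum()
assert torch.isnan(loss).any()  # one NaN term is enough to trip the guard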

View File

@@ -60,6 +60,8 @@ class Engine(DeepSpeedEngine):
 			self.global_samples = stats["global_samples"]
 			self.tokens_processed = stats["tokens_processed"]
 
+		self.max_nan_losses = 8
+
 	def freeze(self, freeze_all=True):
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
 			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
@@ -113,6 +115,11 @@
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
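
For reference, the reason the DeepSpeed engine only gains the NaN-loss guard and no GradScaler: DeepSpeed's fp16 config section carries its own dynamic loss scaling, which is what the Trainer comment above alludes to. A hypothetical config excerpt (keys are DeepSpeed's standard fp16 options; values are illustrative):

ds_config = {
	"fp16": {
		"enabled": True,
		"loss_scale": 0,            # 0 selects dynamic loss scaling
		"initial_scale_power": 16,  # starting scale of 2**16
		"loss_scale_window": 1000,
		"hysteresis": 2,
		"min_loss_scale": 1,
	},
}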