nan loss detection (should have added it earlier), loss scaling for local backend + fp16
parent a755eb3c62
commit 856545f8bb
@@ -544,6 +544,13 @@ class Trainer:
 				return torch.float8_e4m3fn
 		return torch.float32
 
+	@cached_property
+	def scale_loss(self):
+		# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
+		if self.backend != "local":
+			return False
+		return self.dtype == torch.float16
+
 
 @dataclass()
 class Inference:
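The new Trainer.scale_loss property only opts into loss scaling for the local backend when training in float16. A standalone sketch of that gating logic (this is an illustration, not the repo's actual Trainer):

    import torch

    # Illustration only: mirrors the decision the new property makes.
    def scale_loss(backend: str, dtype: torch.dtype) -> bool:
        # DeepSpeed manages its own loss scaling, so only the local backend opts in,
        # and only when the training dtype is float16.
        if backend != "local":
            return False
        return dtype == torch.float16

    assert scale_loss("deepspeed", torch.float16) is False
    assert scale_loss("local", torch.float16) is True
    assert scale_loss("local", torch.float32) is False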
@@ -67,6 +67,9 @@ class Engine():
 
 		self._frozen_params = set()
 
+		self.max_nan_losses = 8
+		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
+
 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@@ -186,8 +189,11 @@ class Engine():
 		return self.module.forward(*args, **kwargs)
 
 	def backward(self, loss):
+		if self.loss_scaler is not None:
+			return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
+
 		return (loss / self.gradient_accumulation_steps).backward()
 
 	def step(self):
 		with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
 			self.micro_steps += 1
@@ -197,7 +203,11 @@ class Engine():
 					torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
 
 				self.global_steps += 1
-				self.optimizer.step()
+				if self.loss_scaler is not None:
+					self.loss_scaler.step(self.optimizer)
+					self.loss_scaler.update()
+				else:
+					self.optimizer.step()
 				self.optimizer.zero_grad()
 
 				self._get_grad_norm()
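Together, the backward() and step() changes above follow the usual torch.cuda.amp.GradScaler cycle: scale the loss before backward, then let the scaler unscale the gradients, apply (or skip) the optimizer step, and update its scale factor. A minimal self-contained sketch, with a toy model and optimizer that are illustrative rather than taken from this repo:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(8, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    x = torch.randn(4, 8, device=device)
    y = torch.randn(4, 1, device=device)

    with torch.autocast(device_type=device, dtype=torch.float16, enabled=(device == "cuda")):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()  # scale up the loss so fp16 gradients don't underflow
    scaler.step(optimizer)         # unscales gradients; skips the step if any are inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration
    optimizer.zero_grad()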
@@ -232,6 +242,11 @@ class Engine():
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
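The NaN check above gives training a small budget of bad batches before bailing out, rather than failing on the first NaN. A standalone sketch of the same countdown pattern (names are illustrative, not from this repo):

    import torch

    class NaNGuard:
        # Aborts after a fixed number of NaN losses instead of failing on the first one.
        def __init__(self, budget: int = 8):
            self.budget = budget

        def check(self, loss: torch.Tensor):
            if torch.isnan(loss).any():
                self.budget -= 1
                if self.budget < 0:
                    raise RuntimeError("Too many NaN losses detected.")

    guard = NaNGuard(budget=8)
    guard.check(torch.tensor(1.0))           # healthy loss, budget untouched
    guard.check(torch.tensor(float("nan")))  # consumes one of the 8 allowed NaN losses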
@@ -60,6 +60,8 @@ class Engine(DeepSpeedEngine):
 		self.global_samples = stats["global_samples"]
 		self.tokens_processed = stats["tokens_processed"]
 
+		self.max_nan_losses = 8
+
 	def freeze(self, freeze_all=True):
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
 			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
@@ -113,6 +115,11 @@ class Engine(DeepSpeedEngine):
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")