remove nan checks because it causes problems in distributed training because I'm not syncing between GPUs (and nan losses gets ignored anyways with loss scaling)

This commit is contained in:
mrq 2024-12-15 09:42:54 -06:00
parent 2ba6b483dc
commit 4800e7179a
3 changed files with 6 additions and 12 deletions

View File

@ -76,7 +76,6 @@ class Engine():
self._frozen_params = set()
self.max_nan_losses = 8
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
self.current_batch_size = 0
@ -294,12 +293,6 @@ class Engine():
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
return stats
self.backward(loss)
self.step()
@ -545,12 +538,16 @@ class Engines(dict[str, Engine]):
raise RuntimeError("Out of memory during forward pass!")
"""
# this causes problems in distributed training
# it's probably required to do all_reduce for nan checks
"""
# no results are returned when a nan is encountered, so catch it here too
if res is None:
engine.max_nan_losses = engine.max_nan_losses - 1
if engine.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
continue
"""
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")

View File

@ -62,9 +62,7 @@ class Engine(DeepSpeedEngine):
self.global_samples = stats["global_samples"]
self.tokens_processed = stats["tokens_processed"]
self.max_nan_losses = 8
self.current_batch_size = 0
self.skip_on_nan = True
def freeze(self, freeze_all=True):
# freeze non-LoRA params if requested
@ -151,12 +149,14 @@ class Engine(DeepSpeedEngine):
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
"""
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
return stats
"""
self.backward(loss)
self.step()

View File

@ -101,9 +101,6 @@ def train_feeder(engine, batch, teacher=None):
loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
return
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}