remove NaN checks: they cause problems in distributed training because I'm not syncing between GPUs (and NaN losses get ignored anyway with loss scaling)
parent 2ba6b483dc
commit 4800e7179a
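A note on the "loss scaling" remark in the message: under AMP, a NaN loss is already tolerated by the scaler itself, because GradScaler.step() skips the optimizer update whenever the unscaled gradients contain inf/NaN, and GradScaler.update() then shrinks the scale. A minimal sketch of that behavior, where model, optimizer, and loader are hypothetical placeholders rather than this repo's objects:

import torch

scaler = torch.cuda.amp.GradScaler()  # the same scaler gated behind cfg.trainer.scale_loss below

for batch in loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # silently skipped if the unscaled grads contain inf/NaN
    scaler.update()         # lowers the scale after a skipped step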
@@ -76,7 +76,6 @@ class Engine():
         self._frozen_params = set()
 
-        self.max_nan_losses = 8
         self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
 
         self.current_batch_size = 0
 
@@ -294,12 +293,6 @@ class Engine():
         stats |= {k: v.item() for k, v in losses.items()}
         stats |= self.gather_attribute("scalar")
 
-        if torch.isnan(loss).any():
-            self.max_nan_losses = self.max_nan_losses - 1
-            if self.max_nan_losses < 0:
-                raise RuntimeError("Too many NaN losses detected.")
-            return stats
-
         self.backward(loss)
         self.step()
 
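The branch removed above is the per-rank control flow the commit message is worried about: torch.isnan(loss) is evaluated locally, so one rank can return early from the step while its peers proceed into backward() and block inside the gradient all-reduce waiting for it. A hypothetical sketch of the hazard (engine and the step structure are stand-ins, not this repo's exact code):

import torch

def training_step(engine, loss):
    # evaluated independently on every rank; nothing synchronizes the outcome
    if torch.isnan(loss).any():
        return None  # the rank that saw a NaN skips backward()/step()...

    engine.backward(loss)  # ...while ranks with a finite loss enter the
    engine.step()          # gradient collective and wait on it indefinitely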
@@ -545,12 +538,16 @@ class Engines(dict[str, Engine]):
                 raise RuntimeError("Out of memory during forward pass!")
                 """
 
+                # this causes problems in distributed training
+                # it's probably required to do all_reduce for nan checks
+                """
                 # no results are returned when a nan is encountered, so catch it here too
                 if res is None:
                     engine.max_nan_losses = engine.max_nan_losses - 1
                     if engine.max_nan_losses < 0:
                         raise RuntimeError("Too many NaN losses detected.")
                     continue
+                """
 
                 loss, engine_stats = res
                 engine_stats |= self.gather_attribute("scalar")
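The new "it's probably required to do all_reduce for nan checks" comment names the fix that would make this block safe to re-enable: reduce the NaN flag across ranks before branching on it. A minimal sketch of that idea, assuming an initialized default process group (the helper name is made up for illustration):

import torch
import torch.distributed as dist

def nan_on_any_rank(loss: torch.Tensor) -> bool:
    # each rank contributes its local flag; MAX-reducing it means every
    # rank sees 1.0 if any single rank observed a NaN
    flag = torch.isnan(loss).any().float()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())

Because every rank reduces the same flag, they all take the same branch: either everyone skips the step or no one does, so the later collectives stay aligned, at the cost of one extra all_reduce per step.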
@@ -62,9 +62,7 @@ class Engine(DeepSpeedEngine):
         self.global_samples = stats["global_samples"]
         self.tokens_processed = stats["tokens_processed"]
 
-        self.max_nan_losses = 8
         self.current_batch_size = 0
-        self.skip_on_nan = True
 
     def freeze(self, freeze_all=True):
         # freeze non-LoRA params if requested
@@ -151,12 +149,14 @@ class Engine(DeepSpeedEngine):
         stats |= {k: v.item() for k, v in losses.items()}
         stats |= self.gather_attribute("scalar")
 
+        """
         if torch.isnan(loss).any():
            self.max_nan_losses = self.max_nan_losses - 1
            if self.max_nan_losses < 0:
                raise RuntimeError("Too many NaN losses detected.")
 
            return stats
+        """
 
         self.backward(loss)
         self.step()
@@ -101,9 +101,6 @@ def train_feeder(engine, batch, teacher=None):
     loss = torch.stack([*losses.values()]).sum()
 
-    if torch.isnan(loss).any():
-        return
-
     stats = {}
     stats |= {k: v.item() for k, v in losses.items()}
     stats |= {k: v.item() for k, v in stat.items()}
 