remove NaN checks because they cause problems in distributed training, since I'm not syncing between GPUs (and NaN losses get ignored anyway with loss scaling)
parent 2ba6b483dc
commit 4800e7179a
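On the "(and NaN losses get ignored anyway with loss scaling)" point: when the loss goes through torch.cuda.amp.GradScaler (the loss_scaler path in the first hunk below), scaler.step() unscales the gradients first and, if any of them are inf/NaN, skips the optimizer step entirely; scaler.update() then lowers the scale. A minimal sketch of that behavior, with an illustrative model, optimizer, and batch that are not from this repo (assumes a CUDA device):

import torch

# stand-ins for the engine's module and optimizer, illustrative only
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

def training_step(batch):
	with torch.autocast(device_type="cuda", dtype=torch.float16):
		loss = model(batch).sum()

	optimizer.zero_grad()
	# scale the loss and backprop scaled gradients
	scaler.scale(loss).backward()
	# unscales gradients; if any are inf/NaN, optimizer.step() is skipped internally
	scaler.step(optimizer)
	# a skipped step shrinks the scale factor, a clean step may grow it
	scaler.update()
	return loss

loss = training_step(torch.randn(4, 8, device="cuda"))

So a NaN loss costs one skipped step but does not corrupt the weights, which is the behavior the commit message leans on.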
@@ -76,7 +76,6 @@ class Engine():
 
 		self._frozen_params = set()
 
-		self.max_nan_losses = 8
 		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
 
 		self.current_batch_size = 0
@@ -294,12 +293,6 @@ class Engine():
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
 
-		if torch.isnan(loss).any():
-			self.max_nan_losses = self.max_nan_losses - 1
-			if self.max_nan_losses < 0:
-				raise RuntimeError("Too many NaN losses detected.")
-			return stats
-
 		self.backward(loss)
 		self.step()
 
@@ -545,12 +538,16 @@ class Engines(dict[str, Engine]):
 				raise RuntimeError("Out of memory during forward pass!")
 			"""
 
+			# this causes problems in distributed training
+			# it's probably required to do all_reduce for nan checks
+			"""
 			# no results are returned when a nan is encountered, so catch it here too
 			if res is None:
 				engine.max_nan_losses = engine.max_nan_losses - 1
 				if engine.max_nan_losses < 0:
 					raise RuntimeError("Too many NaN losses detected.")
 				continue
+			"""
 
 			loss, engine_stats = res
 			engine_stats |= self.gather_attribute("scalar")
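The new "# it's probably required to do all_reduce for nan checks" note points at why the per-rank check was unsafe: each rank only sees its own loss, so one rank could continue (or return) before backward() while the others enter the gradient all-reduce, and the collectives desynchronize. A hedged sketch of what a synchronized check could look like; this is not code from this repo, loss_is_nan_anywhere is a hypothetical helper, and it assumes torch.distributed has been initialized:

import torch
import torch.distributed as dist

def loss_is_nan_anywhere(loss: torch.Tensor) -> bool:
	# 1.0 on ranks that saw a non-finite loss, 0.0 elsewhere
	flag = torch.tensor(
		[0.0 if torch.isfinite(loss).all() else 1.0],
		device=loss.device,
	)
	if dist.is_available() and dist.is_initialized():
		# MAX reduce: every rank sees 1.0 if any rank saw a bad loss
		dist.all_reduce(flag, op=dist.ReduceOp.MAX)
	return bool(flag.item())

# usage inside a training step: every rank takes the same branch,
# so either all ranks skip the update or none do
# if loss_is_nan_anywhere(loss):
#     return stats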
@@ -62,9 +62,7 @@ class Engine(DeepSpeedEngine):
 		self.global_samples = stats["global_samples"]
 		self.tokens_processed = stats["tokens_processed"]
 
-		self.max_nan_losses = 8
 		self.current_batch_size = 0
-		self.skip_on_nan = True
 
 	def freeze(self, freeze_all=True):
 		# freeze non-LoRA params if requested
@@ -151,12 +149,14 @@ class Engine(DeepSpeedEngine):
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
 
+		"""
 		if torch.isnan(loss).any():
 			self.max_nan_losses = self.max_nan_losses - 1
 			if self.max_nan_losses < 0:
 				raise RuntimeError("Too many NaN losses detected.")
 
			return stats
+		"""
 
 		self.backward(loss)
 		self.step()
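For the DeepSpeed backend above, the same "ignored anyway with loss scaling" argument applies: with fp16 dynamic loss scaling enabled, DeepSpeed detects gradient overflow (inf/NaN), skips that optimizer step, and lowers the loss scale instead of raising. A hedged sketch of the relevant config block, with illustrative values that are not taken from this repo's config:

# fragment of a DeepSpeed config dict; dynamic loss scaling is enabled
# by setting "loss_scale" to 0
ds_config = {
	"train_micro_batch_size_per_gpu": 8,   # illustrative value
	"fp16": {
		"enabled": True,
		"loss_scale": 0,            # 0 => dynamic loss scaling
		"initial_scale_power": 16,  # starting scale = 2**16
		"loss_scale_window": 1000,  # overflow-free steps before the scale grows
		"hysteresis": 2,
		"min_loss_scale": 1,
	},
}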
@@ -101,9 +101,6 @@ def train_feeder(engine, batch, teacher=None):
 
 	loss = torch.stack([*losses.values()]).sum()
 
-	if torch.isnan(loss).any():
-		return
-
 	stats = {}
 	stats |= {k: v.item() for k, v in losses.items()}
 	stats |= {k: v.item() for k, v in stat.items()}