mrq 2024-11-01 22:36:48 -05:00
parent ec79230965
commit 62fe5b0943
2 changed files with 22 additions and 31 deletions

View File

@@ -273,14 +273,15 @@ class Engine():
losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.")
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
return stats
self.backward(loss)
self.step()
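
For readability, here is a minimal sketch of how the training step reads with this hunk applied. It assumes the stats block now precedes the NaN check (the added return stats needs stats to already exist) and that torch and the rest of the method's surroundings are defined as in the file; the comments are explanatory and are not part of the original code.

    losses = self.gather_attribute("loss")
    loss = torch.stack([*losses.values()]).sum()

    # Stats are built before the NaN check so an aborted step can still
    # report per-loss values to the caller.
    stats = {}
    stats |= {k: v.item() for k, v in losses.items()}
    stats |= self.gather_attribute("scalar")

    if torch.isnan(loss).any():
        # Spend one unit of the NaN budget; give up once it is exhausted.
        self.max_nan_losses = self.max_nan_losses - 1
        if self.max_nan_losses < 0:
            raise RuntimeError("Too many NaN losses detected.")
        # Skip backward/step for this batch but still return the stats.
        return stats

    self.backward(loss)
    self.step()
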
@@ -480,42 +481,32 @@ class Engines(dict[str, Engine]):
start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=device)
batch = to_device(batch, device)
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch )
else:
while tries >= 0:
try:
res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e:
_logger.error(f"Forward: {str(e)}")
try:
res = feeder( engine=engine, batch=batch )
except RuntimeError as e:
_logger.error(f"Forward: {str(e)}")
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
if "out of memory" not in str(e):
self.save_checkpoint()
raise e
# shrink batch size until it's happy
for k in batch:
batch[k] = batch[k][:-1]
if tries <= 0:
# trigger OOM
n_ooms += 1
else:
# also do GC
do_gc()
continue
n_ooms += 1
if world_size() > 1:
all_reduce(n_ooms)
if n_ooms.item() > 0:
continue
"""
self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!")
"""
if res is None:
continue
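
With the deleted retry loop out of the way, the forward-pass OOM handling that remains is easier to see as a whole. The condensed sketch below assumes it sits inside the per-engine loop of Engines.step() (hence the continue statements) and that batch, device, feeder, cfg, _logger, to_device, world_size and all_reduce are already in scope as in the surrounding file; only the control flow shown in the hunk above is taken from the diff.

    for name, engine in self.items():  # assumed shape of the per-engine loop
        device = engine.device         # assumed; matches the device used above
        batch = to_device(batch, device)

        n_ooms = torch.zeros([], device=device)

        if not cfg.trainer.check_for_oom:
            res = feeder( engine=engine, batch=batch )
        else:
            try:
                res = feeder( engine=engine, batch=batch )
            except RuntimeError as e:
                _logger.error(f"Forward: {str(e)}")

                # Anything other than an OOM is still fatal: checkpoint, then re-raise.
                if "out of memory" not in str(e):
                    self.save_checkpoint()
                    raise e

                # Count the OOM instead of shrinking the batch and retrying.
                n_ooms += 1

            # Let every rank agree on whether anyone ran out of memory this step.
            if world_size() > 1:
                all_reduce(n_ooms)

            # If any rank hit an OOM, drop this batch and move on.
            if n_ooms.item() > 0:
                continue

        if res is None:
            continue
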
@@ -523,8 +514,6 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom:
engine.backward(loss)
else:
@@ -545,8 +534,7 @@ class Engines(dict[str, Engine]):
if n_ooms.item() > 0:
self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!")
raise RuntimeError("Out of memory during backwards pass!")
engine.step()
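
The backward pass keeps the same count-and-reduce pattern but, unlike the forward pass, still aborts outright on OOM. The sketch below is one plausible reading: the try/except body between the two displayed hunks is not shown in this diff and is assumed here, as is the reuse of the n_ooms tensor zeroed before the forward pass (its separate reset is what the hunk above drops).

    if not cfg.trainer.check_for_oom:
        engine.backward(loss)
    else:
        try:
            engine.backward(loss)
        except RuntimeError as e:
            # Assumed to mirror the forward-pass handling above.
            _logger.error(f"Backward: {str(e)}")
            if "out of memory" not in str(e):
                self.save_checkpoint()
                raise e
            n_ooms += 1

        if world_size() > 1:
            all_reduce(n_ooms)

        # Unlike the forward pass, a backward-pass OOM is treated as fatal.
        if n_ooms.item() > 0:
            raise RuntimeError("Out of memory during backwards pass!")

    engine.step()
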

View File

@@ -47,6 +47,9 @@ def train_feeder(engine, batch):
loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
return
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()}
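
Finally, a minimal sketch of where the new guard lands in train_feeder. Only the loss/stats handling and the NaN early return come from this hunk; the forward call and the gather_attribute lookups are assumed context, and the signature is the one shown in the hunk header.

    def train_feeder(engine, batch):
        # Forward pass; the exact call is not shown in this diff.
        engine( **batch )

        # Assumed: the same gather_attribute mechanism used by Engine above.
        losses = engine.gather_attribute("loss")
        stat = engine.gather_attribute("stats")

        loss = torch.stack([*losses.values()]).sum()

        # New guard: a NaN loss makes the feeder return None, which the
        # "if res is None: continue" check in Engines.step() treats as
        # "skip this batch" rather than backpropagating a NaN.
        if torch.isnan(loss).any():
            return

        stats = {}
        stats |= {k: v.item() for k, v in losses.items()}
        stats |= {k: v.item() for k, v in stat.items()}

        return loss, stats
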