This commit is contained in:
mrq 2024-11-01 22:36:48 -05:00
parent ec79230965
commit 62fe5b0943
2 changed files with 22 additions and 31 deletions

View File

@ -273,14 +273,15 @@ class Engine():
losses = self.gather_attribute("loss") losses = self.gather_attribute("loss")
loss = torch.stack([*losses.values()]).sum() loss = torch.stack([*losses.values()]).sum()
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
if torch.isnan(loss).any(): if torch.isnan(loss).any():
self.max_nan_losses = self.max_nan_losses - 1 self.max_nan_losses = self.max_nan_losses - 1
if self.max_nan_losses < 0: if self.max_nan_losses < 0:
raise RuntimeError("Too many NaN losses detected.") raise RuntimeError("Too many NaN losses detected.")
return stats
stats = {}
stats |= {k: v.item() for k, v in losses.items()}
stats |= self.gather_attribute("scalar")
self.backward(loss) self.backward(loss)
self.step() self.step()
@ -480,18 +481,14 @@ class Engines(dict[str, Engine]):
start_time = time.time() start_time = time.time()
tries = 4
n_ooms = torch.zeros([], device=device)
batch = to_device(batch, device) batch = to_device(batch, device)
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom: if not cfg.trainer.check_for_oom:
res = feeder( engine=engine, batch=batch ) res = feeder( engine=engine, batch=batch )
else: else:
while tries >= 0:
try: try:
res = feeder( engine=engine, batch=batch ) res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e: except RuntimeError as e:
_logger.error(f"Forward: {str(e)}") _logger.error(f"Forward: {str(e)}")
@ -499,23 +496,17 @@ class Engines(dict[str, Engine]):
self.save_checkpoint() self.save_checkpoint()
raise e raise e
# shrink batch size until it's happy
for k in batch:
batch[k] = batch[k][:-1]
if tries <= 0:
# trigger OOM
n_ooms += 1 n_ooms += 1
else:
# also do GC
do_gc()
continue
if world_size() > 1: if world_size() > 1:
all_reduce(n_ooms) all_reduce(n_ooms)
if n_ooms.item() > 0: if n_ooms.item() > 0:
continue
"""
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during forward pass!") raise RuntimeError("Out of memory during forward pass!")
"""
if res is None: if res is None:
continue continue
@ -523,8 +514,6 @@ class Engines(dict[str, Engine]):
loss, engine_stats = res loss, engine_stats = res
engine_stats |= self.gather_attribute("scalar") engine_stats |= self.gather_attribute("scalar")
n_ooms = torch.zeros([], device=device)
if not cfg.trainer.check_for_oom: if not cfg.trainer.check_for_oom:
engine.backward(loss) engine.backward(loss)
else: else:
@ -545,7 +534,6 @@ class Engines(dict[str, Engine]):
if n_ooms.item() > 0: if n_ooms.item() > 0:
self.save_checkpoint() self.save_checkpoint()
raise RuntimeError("Out of memory during backwards pass!") raise RuntimeError("Out of memory during backwards pass!")
engine.step() engine.step()

View File

@ -47,6 +47,9 @@ def train_feeder(engine, batch):
loss = torch.stack([*losses.values()]).sum() loss = torch.stack([*losses.values()]).sum()
if torch.isnan(loss).any():
return
stats = {} stats = {}
stats |= {k: v.item() for k, v in losses.items()} stats |= {k: v.item() for k, v in losses.items()}
stats |= {k: v.item() for k, v in stat.items()} stats |= {k: v.item() for k, v in stat.items()}