ughh

parent ec79230965
commit 62fe5b0943

@@ -273,14 +273,15 @@ class Engine():
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		stats = {}
+		stats |= {k: v.item() for k, v in losses.items()}
+		stats |= self.gather_attribute("scalar")
+
 		if torch.isnan(loss).any():
 			self.max_nan_losses = self.max_nan_losses - 1
 			if self.max_nan_losses < 0:
 				raise RuntimeError("Too many NaN losses detected.")
-
-		stats = {}
-		stats |= {k: v.item() for k, v in losses.items()}
-		stats |= self.gather_attribute("scalar")
+			return stats
 
 		self.backward(loss)
 		self.step()
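
The hunk above moves the stats construction ahead of the NaN check, so a NaN step still returns its per-loss numbers while skipping backward/step, and only gives up once the NaN budget runs out. Below is a minimal standalone sketch of that guard, assuming a torch-style engine that owns its optimizer; ToyEngine, its constructor arguments, and the losses-dict input are illustrative and not the repo's real Engine:

import torch

class ToyEngine:
	# Illustrative stand-in: only the NaN-guard flow from the hunk above.
	def __init__(self, model, optimizer, max_nan_losses=8):
		self.model = model
		self.optimizer = optimizer
		self.max_nan_losses = max_nan_losses  # how many NaN steps to tolerate before bailing

	def traverse(self, losses: dict[str, torch.Tensor]) -> dict:
		loss = torch.stack([*losses.values()]).sum()

		# stats are built before the NaN check so a NaN step still reports its values
		stats = {k: v.item() for k, v in losses.items()}

		if torch.isnan(loss).any():
			self.max_nan_losses -= 1
			if self.max_nan_losses < 0:
				raise RuntimeError("Too many NaN losses detected.")
			return stats  # skip backward/step for this batch

		loss.backward()
		self.optimizer.step()
		self.optimizer.zero_grad()
		return stats
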
@@ -480,42 +481,32 @@ class Engines(dict[str, Engine]):
 
 			start_time = time.time()
 
-			tries = 4
-			n_ooms = torch.zeros([], device=device)
-
 			batch = to_device(batch, device)
+			n_ooms = torch.zeros([], device=device)
+
 			if not cfg.trainer.check_for_oom:
 				res = feeder( engine=engine, batch=batch )
 			else:
-				while tries >= 0:
-					try:
-						res = feeder( engine=engine, batch=batch )
-						break
-					except RuntimeError as e:
-						_logger.error(f"Forward: {str(e)}")
+				try:
+					res = feeder( engine=engine, batch=batch )
+				except RuntimeError as e:
+					_logger.error(f"Forward: {str(e)}")
 
-						if "out of memory" not in str(e):
-							self.save_checkpoint()
-							raise e
+					if "out of memory" not in str(e):
+						self.save_checkpoint()
+						raise e
 
-						# shrink batch size until it's happy
-						for k in batch:
-							batch[k] = batch[k][:-1]
-
-						if tries <= 0:
-							# trigger OOM
-							n_ooms += 1
-						else:
-							# also do GC
-							do_gc()
-							continue
+					n_ooms += 1
 
 			if world_size() > 1:
 				all_reduce(n_ooms)
 
 			if n_ooms.item() > 0:
+				continue
+				"""
 				self.save_checkpoint()
 				raise RuntimeError("Out of memory during forward pass!")
+				"""
 
 			if res is None:
 				continue
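
This hunk replaces the shrink-and-retry loop with a single try/except: a CUDA OOM bumps a counter tensor, the counter is all_reduce'd so every rank sees the same total, and the whole batch is skipped with continue (the hard raise stays parked inside the string literal). A condensed sketch of that pattern, assuming plain torch.distributed; forward_with_oom_skip, its device argument, and returning None to mean "skip" are illustrative, and the real code also logs and checkpoints before re-raising non-OOM errors:

import torch
import torch.distributed as dist

def forward_with_oom_skip(engine, batch, feeder, device="cuda"):
	# Sketch only: run the forward pass; if any rank OOMs, every rank skips the
	# batch together instead of desyncing on the next collective.
	n_ooms = torch.zeros([], device=device)
	res = None

	try:
		res = feeder(engine=engine, batch=batch)
	except RuntimeError as e:
		if "out of memory" not in str(e):
			raise  # non-OOM errors still propagate
		n_ooms += 1

	# every rank must reach the same decision
	if dist.is_available() and dist.is_initialized():
		dist.all_reduce(n_ooms)

	if n_ooms.item() > 0:
		return None  # caller treats None as "skip this batch"
	return res
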
@@ -523,8 +514,6 @@ class Engines(dict[str, Engine]):
 			loss, engine_stats = res
 			engine_stats |= self.gather_attribute("scalar")
 
-			n_ooms = torch.zeros([], device=device)
-
 			if not cfg.trainer.check_for_oom:
 				engine.backward(loss)
 			else:
@@ -545,8 +534,7 @@ class Engines(dict[str, Engine]):
 
 			if n_ooms.item() > 0:
 				self.save_checkpoint()
-
-			raise RuntimeError("Out of memory during backwards pass!")
+				raise RuntimeError("Out of memory during backwards pass!")
 
 			engine.step()
 
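
These two hunks tidy the backward-pass side of the same OOM handling: the extra n_ooms reset goes away (the forward-pass counter is still zero whenever a step reaches backward, since any forward OOM continues), and the raise is pulled back under save_checkpoint(). The try/except that actually increments n_ooms sits between the hunks and is not shown; the sketch below fills that gap by analogy with the forward pass, purely as an illustration (backward_with_oom_abort, the save_checkpoint callable, and the device argument are assumptions):

import torch
import torch.distributed as dist

def backward_with_oom_abort(engine, loss, save_checkpoint, device="cuda"):
	# Sketch only: backward-pass counterpart to the forward-pass skip above.
	# An OOM during backward is treated as fatal after checkpointing, not skipped.
	n_ooms = torch.zeros([], device=device)

	try:
		engine.backward(loss)
	except RuntimeError as e:
		if "out of memory" not in str(e):
			save_checkpoint()
			raise
		n_ooms += 1

	# agree across ranks before deciding to abort
	if dist.is_available() and dist.is_initialized():
		dist.all_reduce(n_ooms)

	if n_ooms.item() > 0:
		save_checkpoint()
		raise RuntimeError("Out of memory during backwards pass!")

	engine.step()
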
@@ -47,6 +47,9 @@ def train_feeder(engine, batch):
 
 	loss = torch.stack([*losses.values()]).sum()
 
+	if torch.isnan(loss).any():
+		return
+
 	stats = {}
 	stats |= {k: v.item() for k, v in losses.items()}
 	stats |= {k: v.item() for k, v in stat.items()}
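
The last hunk adds the same guard one level up, in train_feeder: if the summed loss is already NaN, the feeder returns early so the engine never tries to backprop through it. A sketch of a feeder with that early return; the engine(batch) call returning a dict of losses and the (loss, stats) return shape are assumptions for illustration:

import torch

def train_feeder(engine, batch):
	# Sketch only: assumed interface where calling the engine yields a dict of losses.
	losses = engine(batch)
	loss = torch.stack([*losses.values()]).sum()

	if torch.isnan(loss).any():
		return None  # caller skips backward/step for this batch

	stats = {k: v.item() for k, v in losses.items()}
	return loss, stats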