ughh
This commit is contained in:
parent
ec79230965
commit
62fe5b0943
|
@ -273,14 +273,15 @@ class Engine():
|
||||||
losses = self.gather_attribute("loss")
|
losses = self.gather_attribute("loss")
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
||||||
|
stats = {}
|
||||||
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
|
stats |= self.gather_attribute("scalar")
|
||||||
|
|
||||||
if torch.isnan(loss).any():
|
if torch.isnan(loss).any():
|
||||||
self.max_nan_losses = self.max_nan_losses - 1
|
self.max_nan_losses = self.max_nan_losses - 1
|
||||||
if self.max_nan_losses < 0:
|
if self.max_nan_losses < 0:
|
||||||
raise RuntimeError("Too many NaN losses detected.")
|
raise RuntimeError("Too many NaN losses detected.")
|
||||||
|
return stats
|
||||||
stats = {}
|
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
|
||||||
stats |= self.gather_attribute("scalar")
|
|
||||||
|
|
||||||
self.backward(loss)
|
self.backward(loss)
|
||||||
self.step()
|
self.step()
|
||||||
|
@ -480,42 +481,32 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
tries = 4
|
|
||||||
n_ooms = torch.zeros([], device=device)
|
|
||||||
|
|
||||||
batch = to_device(batch, device)
|
batch = to_device(batch, device)
|
||||||
|
n_ooms = torch.zeros([], device=device)
|
||||||
|
|
||||||
if not cfg.trainer.check_for_oom:
|
if not cfg.trainer.check_for_oom:
|
||||||
res = feeder( engine=engine, batch=batch )
|
res = feeder( engine=engine, batch=batch )
|
||||||
else:
|
else:
|
||||||
while tries >= 0:
|
try:
|
||||||
try:
|
res = feeder( engine=engine, batch=batch )
|
||||||
res = feeder( engine=engine, batch=batch )
|
except RuntimeError as e:
|
||||||
break
|
_logger.error(f"Forward: {str(e)}")
|
||||||
except RuntimeError as e:
|
|
||||||
_logger.error(f"Forward: {str(e)}")
|
|
||||||
|
|
||||||
if "out of memory" not in str(e):
|
if "out of memory" not in str(e):
|
||||||
self.save_checkpoint()
|
self.save_checkpoint()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# shrink batch size until it's happy
|
n_ooms += 1
|
||||||
for k in batch:
|
|
||||||
batch[k] = batch[k][:-1]
|
|
||||||
|
|
||||||
if tries <= 0:
|
|
||||||
# trigger OOM
|
|
||||||
n_ooms += 1
|
|
||||||
else:
|
|
||||||
# also do GC
|
|
||||||
do_gc()
|
|
||||||
continue
|
|
||||||
|
|
||||||
if world_size() > 1:
|
if world_size() > 1:
|
||||||
all_reduce(n_ooms)
|
all_reduce(n_ooms)
|
||||||
|
|
||||||
if n_ooms.item() > 0:
|
if n_ooms.item() > 0:
|
||||||
|
continue
|
||||||
|
"""
|
||||||
self.save_checkpoint()
|
self.save_checkpoint()
|
||||||
raise RuntimeError("Out of memory during forward pass!")
|
raise RuntimeError("Out of memory during forward pass!")
|
||||||
|
"""
|
||||||
|
|
||||||
if res is None:
|
if res is None:
|
||||||
continue
|
continue
|
||||||
|
@ -523,8 +514,6 @@ class Engines(dict[str, Engine]):
|
||||||
loss, engine_stats = res
|
loss, engine_stats = res
|
||||||
engine_stats |= self.gather_attribute("scalar")
|
engine_stats |= self.gather_attribute("scalar")
|
||||||
|
|
||||||
n_ooms = torch.zeros([], device=device)
|
|
||||||
|
|
||||||
if not cfg.trainer.check_for_oom:
|
if not cfg.trainer.check_for_oom:
|
||||||
engine.backward(loss)
|
engine.backward(loss)
|
||||||
else:
|
else:
|
||||||
|
@ -545,8 +534,7 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
if n_ooms.item() > 0:
|
if n_ooms.item() > 0:
|
||||||
self.save_checkpoint()
|
self.save_checkpoint()
|
||||||
|
raise RuntimeError("Out of memory during backwards pass!")
|
||||||
raise RuntimeError("Out of memory during backwards pass!")
|
|
||||||
|
|
||||||
engine.step()
|
engine.step()
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,9 @@ def train_feeder(engine, batch):
|
||||||
|
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
||||||
|
if torch.isnan(loss).any():
|
||||||
|
return
|
||||||
|
|
||||||
stats = {}
|
stats = {}
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
stats |= {k: v.item() for k, v in stat.items()}
|
stats |= {k: v.item() for k, v in stat.items()}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user