diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index b17f47c..03f5233 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -273,14 +273,15 @@ class Engine(): losses = self.gather_attribute("loss") loss = torch.stack([*losses.values()]).sum() + stats = {} + stats |= {k: v.item() for k, v in losses.items()} + stats |= self.gather_attribute("scalar") + if torch.isnan(loss).any(): self.max_nan_losses = self.max_nan_losses - 1 if self.max_nan_losses < 0: raise RuntimeError("Too many NaN losses detected.") - - stats = {} - stats |= {k: v.item() for k, v in losses.items()} - stats |= self.gather_attribute("scalar") + return stats self.backward(loss) self.step() @@ -480,42 +481,32 @@ class Engines(dict[str, Engine]): start_time = time.time() - tries = 4 - n_ooms = torch.zeros([], device=device) - batch = to_device(batch, device) + n_ooms = torch.zeros([], device=device) if not cfg.trainer.check_for_oom: res = feeder( engine=engine, batch=batch ) else: - while tries >= 0: - try: - res = feeder( engine=engine, batch=batch ) - break - except RuntimeError as e: - _logger.error(f"Forward: {str(e)}") + try: + res = feeder( engine=engine, batch=batch ) + except RuntimeError as e: + _logger.error(f"Forward: {str(e)}") - if "out of memory" not in str(e): - self.save_checkpoint() - raise e + if "out of memory" not in str(e): + self.save_checkpoint() + raise e - # shrink batch size until it's happy - for k in batch: - batch[k] = batch[k][:-1] - - if tries <= 0: - # trigger OOM - n_ooms += 1 - else: - # also do GC - do_gc() - continue + n_ooms += 1 if world_size() > 1: all_reduce(n_ooms) + if n_ooms.item() > 0: + continue + """ self.save_checkpoint() raise RuntimeError("Out of memory during forward pass!") + """ if res is None: continue @@ -523,8 +514,6 @@ class Engines(dict[str, Engine]): loss, engine_stats = res engine_stats |= self.gather_attribute("scalar") - n_ooms = torch.zeros([], device=device) - if not cfg.trainer.check_for_oom: engine.backward(loss) else: @@ -545,8 +534,7 @@ class Engines(dict[str, Engine]): if n_ooms.item() > 0: self.save_checkpoint() - - raise RuntimeError("Out of memory during backwards pass!") + raise RuntimeError("Out of memory during backwards pass!") engine.step() diff --git a/vall_e/train.py b/vall_e/train.py index 0f7c7ee..8f449bc 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -47,6 +47,9 @@ def train_feeder(engine, batch): loss = torch.stack([*losses.values()]).sum() + if torch.isnan(loss).any(): + return + stats = {} stats |= {k: v.item() for k, v in losses.items()} stats |= {k: v.item() for k, v in stat.items()}