diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 537a364..1da2016 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -76,7 +76,6 @@ class Engine():
 		self._frozen_params = set()
 
-		self.max_nan_losses = 8
 		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
 
 		self.current_batch_size = 0
@@ -294,12 +293,6 @@ class Engine():
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
 
-		if torch.isnan(loss).any():
-			self.max_nan_losses = self.max_nan_losses - 1
-			if self.max_nan_losses < 0:
-				raise RuntimeError("Too many NaN losses detected.")
-			return stats
-
 		self.backward(loss)
 		self.step()
@@ -545,12 +538,16 @@ class Engines(dict[str, Engine]):
 				raise RuntimeError("Out of memory during forward pass!")
 			"""
 
+			# this causes problems in distributed training
+			# it's probably required to do all_reduce for nan checks
+			"""
 			# no results are returned when a nan is encountered, so catch it here too
 			if res is None:
 				engine.max_nan_losses = engine.max_nan_losses - 1
 				if engine.max_nan_losses < 0:
 					raise RuntimeError("Too many NaN losses detected.")
 				continue
+			"""
 
 			loss, engine_stats = res
 			engine_stats |= self.gather_attribute("scalar")
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 4edc001..f31d500 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -62,9 +62,7 @@ class Engine(DeepSpeedEngine):
 		self.global_samples = stats["global_samples"]
 		self.tokens_processed = stats["tokens_processed"]
 
-		self.max_nan_losses = 8
 		self.current_batch_size = 0
-		self.skip_on_nan = True
 
 	def freeze(self, freeze_all=True):
 		# freeze non-LoRA params if requested
@@ -151,12 +149,14 @@ class Engine(DeepSpeedEngine):
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
 
+		"""
 		if torch.isnan(loss).any():
 			self.max_nan_losses = self.max_nan_losses - 1
 			if self.max_nan_losses < 0:
 				raise RuntimeError("Too many NaN losses detected.")
 			return stats
+		"""
 
 		self.backward(loss)
 		self.step()
diff --git a/vall_e/train.py b/vall_e/train.py
index 23ac137..69175ca 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -101,9 +101,6 @@ def train_feeder(engine, batch, teacher=None):
 
 	loss = torch.stack([*losses.values()]).sum()
 
-	if torch.isnan(loss).any():
-		return
-
 	stats = {}
 	stats |= {k: v.item() for k, v in losses.items()}
 	stats |= {k: v.item() for k, v in stat.items()}
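
Note, not part of the patch: the comment added in Engines hints at the fix the removed code would need. The problem with the old per-rank check is that a rank hitting a NaN returned early and skipped backward()/step(), while its peers entered the collective ops inside backward(), hanging the job. All-reducing the NaN flag first makes every rank agree to skip or proceed together. A minimal sketch, assuming torch.distributed is the process-group backend; the helper name loss_is_nan_all_ranks is hypothetical, not from this repo:

	import torch
	import torch.distributed as dist

	def loss_is_nan_all_ranks(loss: torch.Tensor) -> bool:
		# hypothetical helper: returns the same answer on every rank
		flag = torch.tensor([float(torch.isnan(loss).any())], device=loss.device)
		if dist.is_available() and dist.is_initialized():
			# MAX reduce: if any single rank raises the flag, all ranks see it raised
			dist.all_reduce(flag, op=dist.ReduceOp.MAX)
		return flag.item() > 0

Because the helper itself performs a collective, every rank must call it unconditionally on every step; gating the call behind a per-rank condition would reintroduce the same desync this commit works around by removing the check.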