From 856545f8bbc0cf1dee38edf1c47b6cd3cd3b520f Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 May 2024 22:23:29 -0500
Subject: [PATCH] nan loss detection (should have added it earlier), loss
 scaling for local backend + fp16

---
 vall_e/config.py            |  7 +++++++
 vall_e/engines/base.py      | 17 ++++++++++++++++-
 vall_e/engines/deepspeed.py |  7 +++++++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 128cbf1..f9ced76 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -544,6 +544,13 @@ class Trainer:
 			return torch.float8_e4m3fn
 		return torch.float32
 
+	@cached_property
+	def scale_loss(self):
+		# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
+		if self.backend != "local":
+			return False
+		return self.dtype == torch.float16
+
 @dataclass()
 class Inference:
 
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index de598e0..66bc8a3 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -67,6 +67,9 @@ class Engine():
 
 		self._frozen_params = set()
 
+		self.max_nan_losses = 8
+		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
+
 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@@ -186,8 +189,11 @@ class Engine():
 		return self.module.forward(*args, **kwargs)
 
 	def backward(self, loss):
+		if self.loss_scaler is not None:
+			return self.loss_scaler.scale(loss / self.gradient_accumulation_steps).backward()
 		return (loss / self.gradient_accumulation_steps).backward()
 
+
 	def step(self):
 		with torch.set_grad_enabled(self.gradient_accumulation_steps > 1):
 			self.micro_steps += 1
@@ -197,7 +203,11 @@ class Engine():
 				torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
 
 				self.global_steps += 1
-				self.optimizer.step()
+				if self.loss_scaler is not None:
+					self.loss_scaler.step(self.optimizer)
+					self.loss_scaler.update()
+				else:
+					self.optimizer.step()
 				self.optimizer.zero_grad()
 
 				self._get_grad_norm()
@@ -232,6 +242,11 @@ class Engine():
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 585eced..1ef8333 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -60,6 +60,8 @@ class Engine(DeepSpeedEngine):
 		self.global_samples = stats["global_samples"]
 		self.tokens_processed = stats["tokens_processed"]
 
+		self.max_nan_losses = 8
+
 	def freeze(self, freeze_all=True):
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
 			raise Exception("freeze_all=False yet self._cfg.frozen_params is None")
@@ -113,6 +115,11 @@ class Engine(DeepSpeedEngine):
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		if torch.isnan(loss).any():
+			self.max_nan_losses = self.max_nan_losses - 1
+			if self.max_nan_losses < 0:
+				raise RuntimeError("Too many NaN losses detected.")
+
 		stats = {}
 		stats |= {k: v.item() for k, v in losses.items()}
 		stats |= self.gather_attribute("scalar")
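
Note: for reference, below is a minimal standalone sketch (not code from this repository) of the fp16 loss-scaling step that the local-backend changes above wire in: the GradScaler scales the loss before backward(), drives optimizer.step() so that steps with overflowed gradients are skipped, and then adjusts its scale factor. PyTorch's documented recipe also calls scaler.unscale_(optimizer) before gradient clipping so the clip threshold applies to the true gradient magnitudes; the sketch follows that recipe. The model, optimizer, and data here are placeholders.

import torch

# placeholder model/optimizer purely for illustration
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()  # same scaler Engine() creates when cfg.trainer.scale_loss is True

for _ in range(10):
	optimizer.zero_grad()
	x = torch.randn(8, 16, device="cuda")
	with torch.autocast(device_type="cuda", dtype=torch.float16):
		loss = model(x).square().mean()          # stand-in for the gathered training loss
	scaler.scale(loss).backward()                # analogous to Engine.backward() when loss_scaler is set
	scaler.unscale_(optimizer)                   # unscale so clipping sees real gradient magnitudes
	torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
	scaler.step(optimizer)                       # analogous to Engine.step(): skips the step on inf/NaN grads
	scaler.update()                              # grows/shrinks the loss scale based on overflow history

Since the scaler silently skips steps whose gradients overflow, it pairs naturally with the max_nan_losses counter added above, which aborts training only after repeated NaN losses.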