diff --git a/vall_e/data.py b/vall_e/data.py index b3d67b8..440a709 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -675,6 +675,10 @@ class Dataset(_Dataset): # this just makes it be happy if len(self.dataset) == 0: self.dataset = cfg.dataset.training + + # hard error because I kept getting tricked by this myself + if self.sampler_order == "duration" and self.sampler_type != "path": + raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.') # dict of paths keyed by speaker names self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type) diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index 4abe44f..04809f2 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -64,6 +64,7 @@ class Engine(DeepSpeedEngine): self.max_nan_losses = 8 self.current_batch_size = 0 + self.skip_on_nan = True def freeze(self, freeze_all=True): # freeze non-LoRA params if requested @@ -142,14 +143,16 @@ class Engine(DeepSpeedEngine): losses = self.gather_attribute("loss") loss = torch.stack([*losses.values()]).sum() + stats = {} + stats |= {k: v.item() for k, v in losses.items()} + stats |= self.gather_attribute("scalar") + if torch.isnan(loss).any(): self.max_nan_losses = self.max_nan_losses - 1 if self.max_nan_losses < 0: raise RuntimeError("Too many NaN losses detected.") - - stats = {} - stats |= {k: v.item() for k, v in losses.items()} - stats |= self.gather_attribute("scalar") + + return stats self.backward(loss) self.step()