skip step on NaN loss (ironically, I have not had a NaN loss since adding this); throw an exception on an invalid cfg.dataset.sample_type and sample_order combination (because I was tricked by this in my YAML and had inconsistent VRAM usage)

mrq 2024-11-01 20:54:53 -05:00
parent fb8faa295b
commit ef1c17430f
2 changed files with 11 additions and 4 deletions


@@ -675,6 +675,10 @@ class Dataset(_Dataset):
 		# this just makes it be happy
 		if len(self.dataset) == 0:
 			self.dataset = cfg.dataset.training
+
+		# hard error because I kept getting tricked by this myself
+		if self.sampler_order == "duration" and self.sampler_type != "path":
+			raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
 
 		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
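For context, a minimal standalone sketch of the guard above; the SimpleNamespace config and the ValueError are illustrative stand-ins, not the repo's actual cfg object or exception type:

from types import SimpleNamespace

# hypothetical stand-in for cfg.dataset; field names taken from the commit message
dataset = SimpleNamespace(sample_type="speaker", sample_order="duration")

# ordering samples by duration only lines up with path-based sampling;
# with any other sample_type the ordering is not honored, presumably the
# source of the inconsistent VRAM usage the commit message mentions
if dataset.sample_order == "duration" and dataset.sample_type != "path":
    raise ValueError(
        f"sample_type={dataset.sample_type} with sample_order={dataset.sample_order} "
        "will not give expected results"
    )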


@@ -64,6 +64,7 @@ class Engine(DeepSpeedEngine):
 		self.max_nan_losses = 8
 		self.current_batch_size = 0
+		self.skip_on_nan = True
 
 	def freeze(self, freeze_all=True):
 		# freeze non-LoRA params if requested
@@ -142,14 +143,16 @@ class Engine(DeepSpeedEngine):
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		stats = {}
+		stats |= {k: v.item() for k, v in losses.items()}
+		stats |= self.gather_attribute("scalar")
+
 		if torch.isnan(loss).any():
 			self.max_nan_losses = self.max_nan_losses - 1
 
 			if self.max_nan_losses < 0:
 				raise RuntimeError("Too many NaN losses detected.")
-		stats = {}
-		stats |= {k: v.item() for k, v in losses.items()}
-		stats |= self.gather_attribute("scalar")
+			return stats
 
 		self.backward(loss)
 		self.step()
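The same skip-on-NaN pattern as a self-contained sketch in plain PyTorch, outside of DeepSpeed; the model, optimizer, and batch are assumed, and the counter/flag names only mirror the attributes added above:

import torch

class NaNGuard:
    """Tolerates a budget of NaN losses, mirroring max_nan_losses / skip_on_nan above."""
    def __init__(self, max_nan_losses=8, skip_on_nan=True):
        self.max_nan_losses = max_nan_losses
        self.skip_on_nan = skip_on_nan

    def should_skip(self, loss: torch.Tensor) -> bool:
        if not torch.isnan(loss).any():
            return False
        self.max_nan_losses -= 1
        if self.max_nan_losses < 0:
            raise RuntimeError("Too many NaN losses detected.")
        return self.skip_on_nan

guard = NaNGuard()

def training_step(model, optimizer, batch):
    loss = model(batch)  # assumed to return a scalar loss tensor
    stats = {"loss": loss.item()}
    if guard.should_skip(loss):
        return stats  # report the NaN, but never let it reach the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return stats

Computing stats before the early return matches the reordering in the hunk above: the NaN still shows up in the reported stats even though backward/step are skipped.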