skip step on nan loss (ironically I have not had a nan loss after adding this); throw an exception on an invalid cfg.dataset.sample_type / sample_order combination (because I was tricked by this in my yaml and had inconsistent vram usage)
This commit is contained in:
parent fb8faa295b
commit ef1c17430f
@@ -675,6 +675,10 @@ class Dataset(_Dataset):
 		# this just makes it be happy
 		if len(self.dataset) == 0:
 			self.dataset = cfg.dataset.training
 
+		# hard error because I kept getting tricked by this myself
+		if self.sampler_order == "duration" and self.sampler_type != "path":
+			raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
+
 		# dict of paths keyed by speaker names
 		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
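The guard exists because sample_order="duration" buckets utterances by length, which only lines up with what actually gets loaded when samples are keyed by path. A minimal standalone sketch of the rejected combination, with stand-in values, and with the reasoning inferred from the commit message rather than the repo:

# standalone sketch of the new guard; "speaker" stands in for any sample_type
# other than "path", and the explanation is inferred from the commit message
sampler_type  = "speaker"    # e.g. cfg.dataset.sample_type from the yaml
sampler_order = "duration"   # duration ordering buckets utterances by length

# drawing samples per speaker instead of per path breaks the duration buckets,
# which shows up as the inconsistent vram usage the commit message mentions
if sampler_order == "duration" and sampler_type != "path":
	raise Exception(f'Requesting sample_type={sampler_type} with sample_order={sampler_order}, yet combination will not give expected results.')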
@@ -64,6 +64,7 @@ class Engine(DeepSpeedEngine):
 
 		self.max_nan_losses = 8
 		self.current_batch_size = 0
+		self.skip_on_nan = True
 
 	def freeze(self, freeze_all=True):
 		# freeze non-LoRA params if requested
@@ -142,14 +143,16 @@ class Engine(DeepSpeedEngine):
 		losses = self.gather_attribute("loss")
 		loss = torch.stack([*losses.values()]).sum()
 
+		stats = {}
+		stats |= {k: v.item() for k, v in losses.items()}
+		stats |= self.gather_attribute("scalar")
+
 		if torch.isnan(loss).any():
 			self.max_nan_losses = self.max_nan_losses - 1
 			if self.max_nan_losses < 0:
 				raise RuntimeError("Too many NaN losses detected.")
 
-		stats = {}
-		stats |= {k: v.item() for k, v in losses.items()}
-		stats |= self.gather_attribute("scalar")
-
+			return stats
+
 		self.backward(loss)
 		self.step()
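Taken together, the Engine hunks implement a skip-on-NaN training step: stats are gathered before the NaN check so the bad loss can still be logged, and a NaN loss consumes one unit of an eight-step budget and returns early instead of reaching backward()/step(). A minimal sketch of the same pattern outside DeepSpeed; Trainer, model, optimizer, and batch are hypothetical names, not the repo's API:

import torch

class Trainer:
	def __init__(self, model, optimizer, max_nan_losses=8):
		self.model = model
		self.optimizer = optimizer
		self.max_nan_losses = max_nan_losses  # budget of tolerated NaN batches

	def step(self, batch):
		loss = self.model(batch)  # assume the model returns a scalar loss tensor
		stats = {"loss": loss.item()}  # gathered before the check so NaNs still get logged

		if torch.isnan(loss).any():
			self.max_nan_losses -= 1
			if self.max_nan_losses < 0:
				raise RuntimeError("Too many NaN losses detected.")
			return stats  # skip backward/step entirely for this batch

		self.optimizer.zero_grad()
		loss.backward()
		self.optimizer.step()
		return stats

The hard RuntimeError once the budget is exhausted matches the same fail-loudly preference as the dataset guard above: better to abort than to keep silently stepping on garbage.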