Skip the optimizer step on a NaN loss (ironically, I have not had a NaN loss since adding this). Also throw an exception on an invalid cfg.dataset.sample_type / sample_order combination, because I was tricked by this in my own YAML and ended up with inconsistent VRAM usage.
This commit is contained in:
parent
fb8faa295b
commit
ef1c17430f
@ -675,6 +675,10 @@ class Dataset(_Dataset):
|
|||||||
# this just makes it be happy
|
# this just makes it be happy
|
||||||
if len(self.dataset) == 0:
|
if len(self.dataset) == 0:
|
||||||
self.dataset = cfg.dataset.training
|
self.dataset = cfg.dataset.training
|
||||||
|
|
||||||
|
# hard error because I kept getting tricked by this myself
|
||||||
|
if self.sampler_order == "duration" and self.sampler_type != "path":
|
||||||
|
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
|
||||||
|
|
||||||
# dict of paths keyed by speaker names
|
# dict of paths keyed by speaker names
|
||||||
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
||||||
|
|||||||
@ -64,6 +64,7 @@ class Engine(DeepSpeedEngine):
|
|||||||
|
|
||||||
self.max_nan_losses = 8
|
self.max_nan_losses = 8
|
||||||
self.current_batch_size = 0
|
self.current_batch_size = 0
|
||||||
|
self.skip_on_nan = True
|
||||||
|
|
||||||
def freeze(self, freeze_all=True):
|
def freeze(self, freeze_all=True):
|
||||||
# freeze non-LoRA params if requested
|
# freeze non-LoRA params if requested
|
||||||
@ -142,14 +143,16 @@ class Engine(DeepSpeedEngine):
|
|||||||
losses = self.gather_attribute("loss")
|
losses = self.gather_attribute("loss")
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
||||||
|
stats = {}
|
||||||
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
|
stats |= self.gather_attribute("scalar")
|
||||||
|
|
||||||
if torch.isnan(loss).any():
|
if torch.isnan(loss).any():
|
||||||
self.max_nan_losses = self.max_nan_losses - 1
|
self.max_nan_losses = self.max_nan_losses - 1
|
||||||
if self.max_nan_losses < 0:
|
if self.max_nan_losses < 0:
|
||||||
raise RuntimeError("Too many NaN losses detected.")
|
raise RuntimeError("Too many NaN losses detected.")
|
||||||
|
|
||||||
stats = {}
|
return stats
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
|
||||||
stats |= self.gather_attribute("scalar")
|
|
||||||
|
|
||||||
self.backward(loss)
|
self.backward(loss)
|
||||||
self.step()
|
self.step()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user