set NaN alarm to show the iteration it happened it

This commit is contained in:
mrq 2023-03-07 19:22:11 +00:00
parent c27ee3ce95
commit 3718e9d0fb

View File

@ -802,8 +802,8 @@ class TrainingState():
if line.find('INFO: [epoch:') >= 0: if line.find('INFO: [epoch:') >= 0:
info_line = line.split("INFO:")[-1] info_line = line.split("INFO:")[-1]
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in info_line: if ': nan' in info_line and not self.self.nan_detected:
self.nan_detected = True self.nan_detected = self.it
# easily rip out our stats... # easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line) match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
@ -962,7 +962,7 @@ class TrainingState():
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]" message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
if self.nan_detected: if self.nan_detected:
message = f"[!NaN DETECTED!] {message}" message = f"[!NaN DETECTED! {self.nan_detected}] {message}"
if message: if message:
percent = self.it / float(self.its) # self.epoch / float(self.epochs) percent = self.it / float(self.its) # self.epoch / float(self.epochs)