diff --git a/src/utils.py b/src/utils.py index 9a00c1c..81265d3 100755 --- a/src/utils.py +++ b/src/utils.py @@ -802,8 +802,8 @@ class TrainingState(): if line.find('INFO: [epoch:') >= 0: info_line = line.split("INFO:")[-1] # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point - if ': nan' in info_line: - self.nan_detected = True + if ': nan' in info_line and not self.self.nan_detected: + self.nan_detected = self.it # easily rip out our stats... match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line) @@ -962,7 +962,7 @@ class TrainingState(): message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]" if self.nan_detected: - message = f"[!NaN DETECTED!] {message}" + message = f"[!NaN DETECTED! {self.nan_detected}] {message}" if message: percent = self.it / float(self.its) # self.epoch / float(self.epochs)