diff --git a/src/utils.py b/src/utils.py index 9b3e5a0..8424c6c 100755 --- a/src/utils.py +++ b/src/utils.py @@ -698,15 +698,16 @@ class TrainingState(): message = None if line.find('INFO: [epoch:') >= 0: + info_line = line.split("INFO:")[-1] # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point - if ': nan' in line: + if ': nan' in info_line: should_return = True print("! NAN DETECTED !") self.buffer.append("! NAN DETECTED !") # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) + match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line) if match and len(match) > 0: for k, v in match: self.info[k] = float(v.replace(",", "")) @@ -714,6 +715,11 @@ class TrainingState(): self.load_losses(update=True) should_return = True + if 'epoch' in self.info: + self.epoch = int(self.info['epoch']) + if 'iter' in self.info: + self.it = int(self.info['iter']) + elif line.find('Saving models and training states') >= 0: self.checkpoint = self.checkpoint + 1 @@ -727,7 +733,7 @@ class TrainingState(): self.cleanup_old(keep=keep_x_past_datasets) - elif line.find('%|') > 0: + if line.find('%|') > 0: match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) if match and len(match) > 0: match = match[0] @@ -839,10 +845,10 @@ class TrainingState(): if deriv != 0: # dloss < 0: next_milestone = None for milestone in self.loss_milestones: - if loss_value < milestone: + if loss_value > milestone: next_milestone = milestone break - + if next_milestone: # tfw can do simple calculus but not basic algebra in my head est_its = (next_milestone - loss_value) / deriv