(should) fix reported epoch metric desyncing from defacto metric, fixed finding next milestone from wrong sign because of 2AM brain

This commit is contained in:
mrq 2023-03-05 20:42:45 +00:00
parent 206a14fdbe
commit 287738a338

View File

@ -698,15 +698,16 @@ class TrainingState():
message = None message = None
if line.find('INFO: [epoch:') >= 0: if line.find('INFO: [epoch:') >= 0:
info_line = line.split("INFO:")[-1]
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in line: if ': nan' in info_line:
should_return = True should_return = True
print("! NAN DETECTED !") print("! NAN DETECTED !")
self.buffer.append("! NAN DETECTED !") self.buffer.append("! NAN DETECTED !")
# easily rip out our stats... # easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
if match and len(match) > 0: if match and len(match) > 0:
for k, v in match: for k, v in match:
self.info[k] = float(v.replace(",", "")) self.info[k] = float(v.replace(",", ""))
@ -714,6 +715,11 @@ class TrainingState():
self.load_losses(update=True) self.load_losses(update=True)
should_return = True should_return = True
if 'epoch' in self.info:
self.epoch = int(self.info['epoch'])
if 'iter' in self.info:
self.it = int(self.info['iter'])
elif line.find('Saving models and training states') >= 0: elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1 self.checkpoint = self.checkpoint + 1
@ -727,7 +733,7 @@ class TrainingState():
self.cleanup_old(keep=keep_x_past_datasets) self.cleanup_old(keep=keep_x_past_datasets)
elif line.find('%|') > 0: if line.find('%|') > 0:
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
if match and len(match) > 0: if match and len(match) > 0:
match = match[0] match = match[0]
@ -839,10 +845,10 @@ class TrainingState():
if deriv != 0: # dloss < 0: if deriv != 0: # dloss < 0:
next_milestone = None next_milestone = None
for milestone in self.loss_milestones: for milestone in self.loss_milestones:
if loss_value < milestone: if loss_value > milestone:
next_milestone = milestone next_milestone = milestone
break break
if next_milestone: if next_milestone:
# tfw can do simple calculus but not basic algebra in my head # tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - loss_value) / deriv est_its = (next_milestone - loss_value) / deriv