(should) fix reported epoch metric desyncing from defacto metric, fixed finding next milestone from wrong sign because of 2AM brain
This commit is contained in:
parent
206a14fdbe
commit
287738a338
16
src/utils.py
16
src/utils.py
|
@ -698,15 +698,16 @@ class TrainingState():
|
||||||
|
|
||||||
message = None
|
message = None
|
||||||
if line.find('INFO: [epoch:') >= 0:
|
if line.find('INFO: [epoch:') >= 0:
|
||||||
|
info_line = line.split("INFO:")[-1]
|
||||||
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
||||||
if ': nan' in line:
|
if ': nan' in info_line:
|
||||||
should_return = True
|
should_return = True
|
||||||
|
|
||||||
print("! NAN DETECTED !")
|
print("! NAN DETECTED !")
|
||||||
self.buffer.append("! NAN DETECTED !")
|
self.buffer.append("! NAN DETECTED !")
|
||||||
|
|
||||||
# easily rip out our stats...
|
# easily rip out our stats...
|
||||||
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
|
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
for k, v in match:
|
for k, v in match:
|
||||||
self.info[k] = float(v.replace(",", ""))
|
self.info[k] = float(v.replace(",", ""))
|
||||||
|
@ -714,6 +715,11 @@ class TrainingState():
|
||||||
self.load_losses(update=True)
|
self.load_losses(update=True)
|
||||||
should_return = True
|
should_return = True
|
||||||
|
|
||||||
|
if 'epoch' in self.info:
|
||||||
|
self.epoch = int(self.info['epoch'])
|
||||||
|
if 'iter' in self.info:
|
||||||
|
self.it = int(self.info['iter'])
|
||||||
|
|
||||||
elif line.find('Saving models and training states') >= 0:
|
elif line.find('Saving models and training states') >= 0:
|
||||||
self.checkpoint = self.checkpoint + 1
|
self.checkpoint = self.checkpoint + 1
|
||||||
|
|
||||||
|
@ -727,7 +733,7 @@ class TrainingState():
|
||||||
|
|
||||||
self.cleanup_old(keep=keep_x_past_datasets)
|
self.cleanup_old(keep=keep_x_past_datasets)
|
||||||
|
|
||||||
elif line.find('%|') > 0:
|
if line.find('%|') > 0:
|
||||||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
match = match[0]
|
match = match[0]
|
||||||
|
@ -839,10 +845,10 @@ class TrainingState():
|
||||||
if deriv != 0: # dloss < 0:
|
if deriv != 0: # dloss < 0:
|
||||||
next_milestone = None
|
next_milestone = None
|
||||||
for milestone in self.loss_milestones:
|
for milestone in self.loss_milestones:
|
||||||
if loss_value < milestone:
|
if loss_value > milestone:
|
||||||
next_milestone = milestone
|
next_milestone = milestone
|
||||||
break
|
break
|
||||||
|
|
||||||
if next_milestone:
|
if next_milestone:
|
||||||
# tfw can do simple calculus but not basic algebra in my head
|
# tfw can do simple calculus but not basic algebra in my head
|
||||||
est_its = (next_milestone - loss_value) / deriv
|
est_its = (next_milestone - loss_value) / deriv
|
||||||
|
|
Loading…
Reference in New Issue
Block a user