reordered things so it uses fresh data and not last-updated data

This commit is contained in:
mrq 2023-03-05 07:37:27 +00:00
parent ce3866d0cd
commit d312019d05

View File

@ -552,6 +552,11 @@ class TrainingState():
self.last_info_check_at = 0
self.statistics = []
self.losses = []
self.metrics = {
'step': "",
'rate': "",
'loss': "",
}
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
@ -691,7 +696,37 @@ class TrainingState():
lapsed = False
message = None
if line.find('%|') > 0:
if line.find('INFO: [epoch:') >= 0:
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in line:
should_return = True
print("! NAN DETECTED !")
self.buffer.append("! NAN DETECTED !")
# easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
if match and len(match) > 0:
for k, v in match:
self.info[k] = float(v.replace(",", ""))
self.load_losses(update=True)
should_return = True
elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1
percent = self.checkpoint / float(self.checkpoints)
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
if progress is not None:
progress(percent, message)
print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
self.cleanup_old(keep=keep_x_past_datasets)
elif line.find('%|') > 0:
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
if match and len(match) > 0:
match = match[0]
@ -722,15 +757,34 @@ class TrainingState():
except Exception as e:
pass
metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
metric_step = ", ".join(metric_step)
self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
self.metrics['step'] = ", ".join(self.metrics['step'])
metric_rate = []
if lapsed:
self.epoch = self.epoch + 1
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
self.epoch_time_end = time.time()
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
self.epoch_time_start = time.time()
self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
self.epoch_taken = self.epoch_taken + 1
self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken)
try:
eta = str(timedelta(seconds=int(self.eta)))
self.eta_hhmmss = eta
except Exception as e:
pass
self.metrics['rate'] = []
if self.epoch_rate:
metric_rate.append(self.epoch_rate)
self.metrics['rate'].append(self.epoch_rate)
if self.it_rate:
metric_rate.append(self.it_rate)
metric_rate = ", ".join(metric_rate)
self.metrics['rate'].append(self.it_rate)
self.metrics['rate'] = ", ".join(self.metrics['rate'])
eta_hhmmss = "?"
if self.eta_hhmmss:
@ -743,9 +797,9 @@ class TrainingState():
except Exception as e:
pass
metric_loss = []
self.metrics['loss'] = []
if len(self.losses) > 0:
metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
self.metrics['loss'].append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
if len(self.losses) >= 2:
# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
@ -771,33 +825,14 @@ class TrainingState():
if next_milestone:
# tfw can do simple calculus but not basic algebra in my head
est_its = (next_milestone - d1_loss) * (dstep / dloss)
metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
else:
est_loss = inst_deriv * its_remain + d1_loss
metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}')
self.metrics['loss'].append(f'Est. final loss: {"{:3f}".format(est_loss)}')
metric_loss = ", ".join(metric_loss)
self.metrics['loss'] = ", ".join(self.metrics['loss'])
message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
if lapsed:
self.epoch = self.epoch + 1
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
self.epoch_time_end = time.time()
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
self.epoch_time_start = time.time()
self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
self.epoch_taken = self.epoch_taken + 1
self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken)
try:
eta = str(timedelta(seconds=int(self.eta)))
self.eta_hhmmss = eta
except Exception as e:
pass
message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
if message:
percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@ -806,36 +841,6 @@ class TrainingState():
self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
if line.find('INFO: [epoch:') >= 0:
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
if ': nan' in line:
should_return = True
print("! NAN DETECTED !")
self.buffer.append("! NAN DETECTED !")
# easily rip out our stats...
match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
if match and len(match) > 0:
for k, v in match:
self.info[k] = float(v.replace(",", ""))
self.load_losses(update=True)
should_return = True
elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1
percent = self.checkpoint / float(self.checkpoints)
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
if progress is not None:
progress(percent, message)
print(f'{"{:.3f}".format(percent*100)}% {message}')
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
self.cleanup_old(keep=keep_x_past_datasets)
if verbose and not self.training_started:
should_return = True