From d312019d0573746cbc833df7397b90defc5093cd Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 5 Mar 2023 07:37:27 +0000 Subject: [PATCH] reordered things so it uses fresh data and not last-updated data --- src/utils.py | 181 ++++++++++++++++++++++++++------------------------- 1 file changed, 93 insertions(+), 88 deletions(-) diff --git a/src/utils.py b/src/utils.py index 5c1c449..90a46b4 100755 --- a/src/utils.py +++ b/src/utils.py @@ -552,6 +552,11 @@ class TrainingState(): self.last_info_check_at = 0 self.statistics = [] self.losses = [] + self.metrics = { + 'step': "", + 'rate': "", + 'loss': "", + } self.loss_milestones = [ 1.0, 0.15, 0.05 ] @@ -691,7 +696,37 @@ class TrainingState(): lapsed = False message = None - if line.find('%|') > 0: + if line.find('INFO: [epoch:') >= 0: + # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point + if ': nan' in line: + should_return = True + + print("! NAN DETECTED !") + self.buffer.append("! NAN DETECTED !") + + # easily rip out our stats... + match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) + if match and len(match) > 0: + for k, v in match: + self.info[k] = float(v.replace(",", "")) + + self.load_losses(update=True) + should_return = True + + elif line.find('Saving models and training states') >= 0: + self.checkpoint = self.checkpoint + 1 + + percent = self.checkpoint / float(self.checkpoints) + message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...' + if progress is not None: + progress(percent, message) + + print(f'{"{:.3f}".format(percent*100)}% {message}') + self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') + + self.cleanup_old(keep=keep_x_past_datasets) + + elif line.find('%|') > 0: match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line) if match and len(match) > 0: match = match[0] @@ -722,63 +757,8 @@ class TrainingState(): except Exception as e: pass - metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"] - metric_step = ", ".join(metric_step) - - metric_rate = [] - if self.epoch_rate: - metric_rate.append(self.epoch_rate) - if self.it_rate: - metric_rate.append(self.it_rate) - metric_rate = ", ".join(metric_rate) - - eta_hhmmss = "?" - if self.eta_hhmmss: - eta_hhmmss = self.eta_hhmmss - else: - try: - eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken) - eta = str(timedelta(seconds=int(eta))) - eta_hhmmss = eta - except Exception as e: - pass - - metric_loss = [] - if len(self.losses) > 0: - metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}') - - if len(self.losses) >= 2: - # i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine - d1_loss = self.losses[-1]["value"] - d2_loss = self.losses[-2]["value"] - dloss = d2_loss - d1_loss - - d1_step = self.losses[-1]["step"] - d2_step = self.losses[-2]["step"] - dstep = d2_step - d1_step - - # don't bother if the loss went up - if dloss < 0: - its_remain = self.its - self.it - inst_deriv = dloss / dstep - - next_milestone = None - for milestone in self.loss_milestones: - if d1_loss > milestone: - next_milestone = milestone - break - - if next_milestone: - # tfw can do simple calculus but not basic algebra in my head - est_its = (next_milestone - d1_loss) * (dstep / dloss) - metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its') - else: - est_loss = inst_deriv * its_remain + d1_loss - metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}') - - metric_loss = ", ".join(metric_loss) - - message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]' + self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"] + self.metrics['step'] = ", ".join(self.metrics['step']) if lapsed: self.epoch = self.epoch + 1 @@ -799,6 +779,61 @@ class TrainingState(): except Exception as e: pass + self.metrics['rate'] = [] + if self.epoch_rate: + self.metrics['rate'].append(self.epoch_rate) + if self.it_rate: + self.metrics['rate'].append(self.it_rate) + self.metrics['rate'] = ", ".join(self.metrics['rate']) + + eta_hhmmss = "?" + if self.eta_hhmmss: + eta_hhmmss = self.eta_hhmmss + else: + try: + eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken) + eta = str(timedelta(seconds=int(eta))) + eta_hhmmss = eta + except Exception as e: + pass + + self.metrics['loss'] = [] + if len(self.losses) > 0: + self.metrics['loss'].append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}') + + if len(self.losses) >= 2: + # i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine + d1_loss = self.losses[-1]["value"] + d2_loss = self.losses[-2]["value"] + dloss = d2_loss - d1_loss + + d1_step = self.losses[-1]["step"] + d2_step = self.losses[-2]["step"] + dstep = d2_step - d1_step + + # don't bother if the loss went up + if dloss < 0: + its_remain = self.its - self.it + inst_deriv = dloss / dstep + + next_milestone = None + for milestone in self.loss_milestones: + if d1_loss > milestone: + next_milestone = milestone + break + + if next_milestone: + # tfw can do simple calculus but not basic algebra in my head + est_its = (next_milestone - d1_loss) * (dstep / dloss) + self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its') + else: + est_loss = inst_deriv * its_remain + d1_loss + self.metrics['loss'].append(f'Est. final loss: {"{:3f}".format(est_loss)}') + + self.metrics['loss'] = ", ".join(self.metrics['loss']) + + message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]" + if message: percent = self.it / float(self.its) # self.epoch / float(self.epochs) if progress is not None: @@ -806,36 +841,6 @@ class TrainingState(): self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}') - if line.find('INFO: [epoch:') >= 0: - # to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point - if ': nan' in line: - should_return = True - - print("! NAN DETECTED !") - self.buffer.append("! NAN DETECTED !") - - # easily rip out our stats... - match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line) - if match and len(match) > 0: - for k, v in match: - self.info[k] = float(v.replace(",", "")) - - self.load_losses(update=True) - should_return = True - - elif line.find('Saving models and training states') >= 0: - self.checkpoint = self.checkpoint + 1 - - percent = self.checkpoint / float(self.checkpoints) - message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...' - if progress is not None: - progress(percent, message) - - print(f'{"{:.3f}".format(percent*100)}% {message}') - self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}') - - self.cleanup_old(keep=keep_x_past_datasets) - if verbose and not self.training_started: should_return = True