From d312019d0573746cbc833df7397b90defc5093cd Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Sun, 5 Mar 2023 07:37:27 +0000
Subject: [PATCH] reordered things so it uses fresh data and not last-updated
 data

---
 src/utils.py | 181 ++++++++++++++++++++++++++-------------------------
 1 file changed, 93 insertions(+), 88 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 5c1c449..90a46b4 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -552,6 +552,11 @@ class TrainingState():
 		self.last_info_check_at = 0
 		self.statistics = []
 		self.losses = []
+		self.metrics = {
+			'step': "",
+			'rate': "",
+			'loss': "",
+		}
 
 		self.loss_milestones = [ 1.0, 0.15, 0.05 ]
 
@@ -691,7 +696,37 @@ class TrainingState():
 			lapsed = False
 
 			message = None
-			if line.find('%|') > 0:
+			if line.find('INFO: [epoch:') >= 0:
+				# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
+				if ': nan' in line:
+					should_return = True
+
+					print("! NAN DETECTED !")
+					self.buffer.append("! NAN DETECTED !")
+
+				# easily rip out our stats...
+				match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
+				if match and len(match) > 0:
+					for k, v in match:
+						self.info[k] = float(v.replace(",", ""))
+
+				self.load_losses(update=True)
+				should_return = True
+
+			elif line.find('Saving models and training states') >= 0:
+				self.checkpoint = self.checkpoint + 1
+
+				percent = self.checkpoint / float(self.checkpoints)
+				message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
+				if progress is not None:
+					progress(percent, message)
+
+				print(f'{"{:.3f}".format(percent*100)}% {message}')
+				self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
+
+				self.cleanup_old(keep=keep_x_past_datasets)
+
+			elif line.find('%|') > 0:
 				match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
 				if match and len(match) > 0:
 					match = match[0]
@@ -722,63 +757,8 @@ class TrainingState():
 						except Exception as e:
 							pass
 
-					metric_step = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
-					metric_step = ", ".join(metric_step)
-
-					metric_rate = []
-					if self.epoch_rate:
-						metric_rate.append(self.epoch_rate)
-					if self.it_rate:
-						metric_rate.append(self.it_rate)
-					metric_rate = ", ".join(metric_rate)
-
-					eta_hhmmss = "?"
-					if self.eta_hhmmss:
-						eta_hhmmss = self.eta_hhmmss
-					else:
-						try:
-							eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
-							eta = str(timedelta(seconds=int(eta)))
-							eta_hhmmss = eta
-						except Exception as e:
-							pass
-					
-					metric_loss = []
-					if len(self.losses) > 0:
-						metric_loss.append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
-
-					if len(self.losses) >= 2:
-						# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
-						d1_loss = self.losses[-1]["value"]
-						d2_loss = self.losses[-2]["value"]
-						dloss = d2_loss - d1_loss
-						
-						d1_step = self.losses[-1]["step"]
-						d2_step = self.losses[-2]["step"]
-						dstep = d2_step - d1_step
-
-						# don't bother if the loss went up
-						if dloss < 0:
-							its_remain = self.its - self.it
-							inst_deriv = dloss / dstep
-
-							next_milestone = None
-							for milestone in self.loss_milestones:
-								if d1_loss > milestone:
-									next_milestone = milestone
-									break
-
-							if next_milestone:
-								# tfw can do simple calculus but not basic algebra in my head
-								est_its = (next_milestone - d1_loss) * (dstep / dloss)
-								metric_loss.append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
-							else:
-								est_loss = inst_deriv * its_remain + d1_loss
-								metric_loss.append(f'Est. final loss: {"{:3f}".format(est_loss)}')
-
-					metric_loss = ", ".join(metric_loss)
-
-					message = f'[{metric_step}] [{metric_rate}] [ETA: {eta_hhmmss}] [{metric_loss}]'
+					self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
+					self.metrics['step'] = ", ".join(self.metrics['step'])
 
 			if lapsed:
 				self.epoch = self.epoch + 1
@@ -799,6 +779,61 @@ class TrainingState():
 				except Exception as e:
 					pass
 
+			self.metrics['rate'] = []
+			if self.epoch_rate:
+				self.metrics['rate'].append(self.epoch_rate)
+			if self.it_rate:
+				self.metrics['rate'].append(self.it_rate)
+			self.metrics['rate'] = ", ".join(self.metrics['rate'])
+
+			eta_hhmmss = "?"
+			if self.eta_hhmmss:
+				eta_hhmmss = self.eta_hhmmss
+			else:
+				try:
+					eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
+					eta = str(timedelta(seconds=int(eta)))
+					eta_hhmmss = eta
+				except Exception as e:
+					pass
+			
+			self.metrics['loss'] = []
+			if len(self.losses) > 0:
+				self.metrics['loss'].append(f'Loss: {"{:3f}".format(self.losses[-1]["value"])}')
+
+			if len(self.losses) >= 2:
+				# i can probably do a """riemann sum""" to get a better derivative, but the instantaneous one works fine
+				d1_loss = self.losses[-1]["value"]
+				d2_loss = self.losses[-2]["value"]
+				dloss = d2_loss - d1_loss
+				
+				d1_step = self.losses[-1]["step"]
+				d2_step = self.losses[-2]["step"]
+				dstep = d2_step - d1_step
+
+				# don't bother if the loss went up
+				if dloss < 0:
+					its_remain = self.its - self.it
+					inst_deriv = dloss / dstep
+
+					next_milestone = None
+					for milestone in self.loss_milestones:
+						if d1_loss > milestone:
+							next_milestone = milestone
+							break
+
+					if next_milestone:
+						# tfw can do simple calculus but not basic algebra in my head
+						est_its = (next_milestone - d1_loss) * (dstep / dloss)
+						self.metrics['loss'].append(f'Est. milestone {next_milestone} in: {int(est_its)}its')
+					else:
+						est_loss = inst_deriv * its_remain + d1_loss
+						self.metrics['loss'].append(f'Est. final loss: {"{:3f}".format(est_loss)}')
+
+			self.metrics['loss'] = ", ".join(self.metrics['loss'])
+
+			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
+
 			if message:
 				percent = self.it / float(self.its) # self.epoch / float(self.epochs)
 				if progress is not None:
@@ -806,36 +841,6 @@ class TrainingState():
 
 				self.buffer.append(f'[{"{:.3f}".format(percent*100)}%] {message}')
 
-			if line.find('INFO: [epoch:') >= 0:
-				# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
-				if ': nan' in line:
-					should_return = True
-
-					print("! NAN DETECTED !")
-					self.buffer.append("! NAN DETECTED !")
-
-				# easily rip out our stats...
-				match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
-				if match and len(match) > 0:
-					for k, v in match:
-						self.info[k] = float(v.replace(",", ""))
-
-				self.load_losses(update=True)
-				should_return = True
-
-			elif line.find('Saving models and training states') >= 0:
-				self.checkpoint = self.checkpoint + 1
-
-				percent = self.checkpoint / float(self.checkpoints)
-				message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
-				if progress is not None:
-					progress(percent, message)
-
-				print(f'{"{:.3f}".format(percent*100)}% {message}')
-				self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
-
-				self.cleanup_old(keep=keep_x_past_datasets)
-
 		if verbose and not self.training_started:
 			should_return = True