From 8b4da29d5fda2e07b5439f62917208b34423c4a3 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Sat, 25 Feb 2023 13:55:25 +0000
Subject: [PATCH] csome adjustments to the training output parser, now updates
 per iteration for really large batches (like the one I'm doing for a dataset
 size of 19420)

---
 src/utils.py | 123 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 90 insertions(+), 33 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index d8150ed..02798d9 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -470,14 +470,21 @@ class TrainingState():
 		self.epoch_rate = ""
 		self.epoch_time_start = 0
 		self.epoch_time_end = 0
+		
+		self.it_rate = ""
+		self.it_time_start = 0
+		self.it_time_end = 0
+		self.last_step = 0
+
 		self.eta = "?"
 		self.eta_hhmmss = "?"
 
 		print("Spawning process: ", " ".join(self.cmd))
 		self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
-	def parse(self, line, verbose=False, buffer_size=8, progress=None):
-		self.buffer.append(f'{line}')
+	def parse(self, line, verbose=False, buffer_size=8, progress=None, owner=True):
+		if owner:
+			self.buffer.append(f'{line}')
 
 		# rip out iteration info
 		if not self.training_started:
@@ -492,47 +499,97 @@ class TrainingState():
 				if match and len(match) > 0:
 					self.it = int(match[0].replace(",", ""))
 		else:
-			if line.find('%|') > 0 and not self.open_state:
-				self.open_state = True
-			elif line.find('100%|') == 0 and self.open_state:
-				self.open_state = False
-				self.epoch = self.epoch + 1
+			lapsed = line.find('100%|') == 0 and self.open_state
 
-				self.epoch_time_end = time.time()
-				self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
-				self.epoch_time_start = time.time()
-				self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
-				self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
-				self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
+			if line.find('%|') > 0:
+				match = re.findall(r' +?(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
+				if match and len(match) > 0:
+					match = match[0]
+					percent = int(match[0])/100.0
+					progressbar = match[1]
+					step = int(match[2])
+					steps = int(match[3])
+					elapsed = match[4]
+					until = match[5]
+					rate = match[6]
+
+					epoch_percent = self.epoch / float(self.epochs)
+
+					if owner:
+						last_step = self.last_step
+						self.last_step = step
+						if last_step < step:
+							self.it = self.it + (step - last_step)
+
+						if last_step > step and step == 0:
+							lapsed = True
+
+						self.it_time_end = time.time()
+						self.it_time_delta = self.it_time_end-self.it_time_start
+						self.it_time_start = time.time()
+						self.it_rate = f'[{"{:.3f}".format(self.it_time_delta)}s/it]' if self.it_time_delta >= 1 else f'[{"{:.3f}".format(1/self.it_time_delta)}it/s]'
+						
+						self.eta = (self.its - self.it) * self.it_time_delta
+						self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
+
+					message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
+					if progress is not None:
+						progress(epoch_percent, message)
+					if owner:
+						# print(f'{"{:.3f}".format(percent*100)}% {message}')
+						self.buffer.append(f'[{"{:.3f}".format(epoch_percent*100)}% / {"{:.3f}".format(percent*100)}%] {message}')
+
+			if line.find('%|') > 0 and not self.open_state:
+				if owner:
+					self.open_state = True
+			elif lapsed:
+				if owner:
+					self.open_state = False
+					self.epoch = self.epoch + 1
+					self.it = int(self.epoch * (self.dataset_size / self.batch_size))
+	 
+					self.epoch_time_end = time.time()
+					self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
+					self.epoch_time_start = time.time()
+					self.epoch_rate = f'[{"{:.3f}".format(self.epoch_time_delta)}s/epoch]' if self.epoch_time_delta >= 1 else f'[{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s]' # I doubt anyone will have it/s rates, but its here
+					self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
+					self.eta_hhmmss = str(timedelta(seconds=int(self.eta)))
 
 				percent = self.epoch / float(self.epochs)
-				message = f'[{self.epoch}/{self.epochs}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} {self.status}'
-				print(f'{"{:.3f}".format(percent*100)}% {message}')
+				message = f'[{self.epoch}/{self.epochs}] [{self.it}/{self.its}] [ETA: {self.eta_hhmmss}] {self.epoch_rate} / {self.it_rate} {self.status}'
+				
 				if progress is not None:
 					progress(percent, message)
-				self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
+
+				if owner:
+					print(f'{"{:.3f}".format(percent*100)}% {message}')
+					self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
 
 			if line.find('INFO: [epoch:') >= 0:
-				# easily rip out our stats...
-				match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
-				if match and len(match) > 0:
-					for k, v in match:
-						self.info[k] = float(v)
-						
-				if 'loss_gpt_total' in self.info:
-					self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
-					print(self.status)
-					self.buffer.append(self.status)
+				if owner:
+					# easily rip out our stats...
+					match = re.findall(r'\b([a-z_0-9]+?)\b: ([0-9]\.[0-9]+?e[+-]\d+)\b', line)
+					if match and len(match) > 0:
+						for k, v in match:
+							self.info[k] = float(v)
+							
+					if 'loss_gpt_total' in self.info:
+						self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
+						print(self.status)
+						self.buffer.append(self.status)
 			elif line.find('Saving models and training states') >= 0:
-				self.checkpoint = self.checkpoint + 1
+				if owner:
+					self.checkpoint = self.checkpoint + 1
 				percent = self.checkpoint / float(self.checkpoints)
 				message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
-				print(f'{"{:.3f}".format(percent*100)}% {message}')
 				if progress is not None:
 					progress(percent, message)
-				self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
+				if owner:
+					print(f'{"{:.3f}".format(percent*100)}% {message}')
+					self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
 
-		self.buffer = self.buffer[-buffer_size:]
+		if owner:
+			self.buffer = self.buffer[-buffer_size:]
 		if verbose or not self.training_started:
 			return "".join(self.buffer)
 
@@ -552,7 +609,7 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
 
 	for line in iter(training_state.process.stdout.readline, ""):
 		
-		res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+		res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True )
 		print(f"[Training] [{datetime.now().isoformat()}] {line[:-1]}")
 		if res:
 			yield res
@@ -565,13 +622,13 @@ def run_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress
 	#if return_code:
 	#	raise subprocess.CalledProcessError(return_code, cmd)
 
-def reconnect_training(config_path, verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
+def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
 	global training_state
 	if not training_state or not training_state.process:
 		return "Training not in progress"
 
 	for line in iter(training_state.process.stdout.readline, ""):
-		res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress )
+		res = training_state.parse( line=line, verbose=verbose, buffer_size=buffer_size, progress=progress, owner=True )
 		if res:
 			yield res