From 8b9c9e1bbf4ef8fb03364942746954dc64fe8570 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Sun, 5 Mar 2023 18:53:12 +0000
Subject: [PATCH] remove redundant stats, add showing LR

---
 src/utils.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index fd0851c..69ad56d 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -758,7 +758,11 @@ class TrainingState():
 						except Exception as e:
 							pass
 
-					self.metrics['step'] = [f"{self.epoch}/{self.epochs}", f"{self.it}/{self.its}", f"{step}/{steps}"]
+					self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
+					if self.epochs != self.its:  # iteration counter is shown only when the totals differ
+						self.metrics['step'].append(f"{self.it}/{self.its}")
+					if steps > 1:  # step counter is shown only when there is more than one step
+						self.metrics['step'].append(f"{step}/{steps}")
 					self.metrics['step'] = ", ".join(self.metrics['step'])
 
 			if lapsed:
@@ -786,7 +790,7 @@ class TrainingState():
 			self.metrics['rate'] = []
 			if self.epoch_rate:
 				self.metrics['rate'].append(self.epoch_rate)
-			if self.it_rate:
+			if self.it_rate and self.epoch_rate != self.it_rate:
 				self.metrics['rate'].append(self.it_rate)
 			self.metrics['rate'] = ", ".join(self.metrics['rate'])
 
@@ -802,6 +806,10 @@ class TrainingState():
 					pass
 			
 			self.metrics['loss'] = []
+
+			if 'learning_rate_gpt_0' in self.info:
+				self.metrics['loss'].append(f'LR: {"{:.9f}".format(self.info["learning_rate_gpt_0"])}')
+
 			if len(self.losses) > 0:
 				self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
 
@@ -845,7 +853,7 @@ class TrainingState():
 
 			self.metrics['loss'] = ", ".join(self.metrics['loss'])
 
-			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}] [{self.metrics['loss']}]"
+			message = f"[{self.metrics['step']}] [{self.metrics['rate']}] [ETA: {eta_hhmmss}]\n[{self.metrics['loss']}]"
 
 			if message:
 				percent = self.it / float(self.its) # self.epoch / float(self.epochs)
@@ -949,6 +957,7 @@ def stop_training():
 	try:
 		children = [p.info for p in psutil.process_iter(attrs=['pid', 'name', 'cmdline']) if './src/train.py' in p.info['cmdline']]
 	except Exception as e:
+		print(e)
 		pass
 
 	training_state.process.stdout.close()