From b989123bd41c7c47e1274713602ca51abf4b5c7a Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Wed, 1 Mar 2023 19:32:11 +0000
Subject: [PATCH] leverage tensorboard to parse tb_logger files when starting
 training (it seems to give a nicer resolution of training data, need to see
 about reading it directly while training)

---
 src/utils.py | 91 +++++++++++++++++++++++++++++-----------------------
 src/webui.py | 12 +++----
 2 files changed, 56 insertions(+), 47 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index b553d6e..6a803a9 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -498,39 +498,57 @@ class TrainingState():
 		self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
 	def load_losses(self):
-		if not os.path.isdir(self.dataset_dir):
+		if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
 			return
+		try:
+			from tensorboard.backend.event_processing import event_accumulator
+			use_tensorboard = True
+		except Exception as e:
+			use_tensorboard = False
 
-		logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
-		infos = {}
-		for log in logs:
-			with open(log, 'r', encoding="utf-8") as f:
-				lines = f.readlines()
-				for line in lines:
-					if line.find('INFO: [epoch:') >= 0:
-						# easily rip out our stats...
-						match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
-						if not match or len(match) == 0:
-							continue
+		if use_tensorboard:
+			logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
+			infos = {}
+			for log in logs:
+				try:
+					ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
+					ea.Reload()
 
-						info = {}
-						for k, v in match:
-							info[k] = float(v.replace(",", ""))
+					keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
+					for key in keys:
+						scalar = ea.Scalars(key)
+						for s in scalar:
+							self.losses.append( { "step": s.step, "value": s.value, "type": key } )
+				except Exception as e:
+					print("Failed to parse event log:", log)
+					pass
 
-						if 'iter' in info:
-							it = info['iter']
-							infos[it] = info
+		else:
+			logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
+			infos = {}
+			for log in logs:
+				with open(log, 'r', encoding="utf-8") as f:
+					lines = f.readlines()
+					for line in lines:
+						if line.find('INFO: [epoch:') >= 0:
+							# easily rip out our stats...
+							match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
+							if not match or len(match) == 0:
+								continue
 
-		for k in infos:
-			if 'loss_gpt_total' in infos[k]:
-				# self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ])
-				self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" })
-				self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" })
-				self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" })
-				"""
-				self.losses['iteration'].append(int(k))
-				self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
-				"""
+							info = {}
+							for k, v in match:
+								info[k] = float(v.replace(",", ""))
+
+							if 'iter' in info:
+								it = info['iter']
+								infos[it] = info
+
+			for k in infos:
+				if 'loss_gpt_total' in infos[k]:
+					self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" })
+					self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" })
+					self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" })
 
 	def cleanup_old(self, keep=2):
 		if keep <= 0:
@@ -606,7 +624,7 @@ class TrainingState():
 						pass
 					last_loss = ""
 					if len(self.losses) > 0:
-						last_loss = f'[Loss @ it. {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}]'
+						last_loss = f'[Loss @ it. {self.losses[-1]["step"]}: {self.losses[-1]["value"]}]'
 					message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] {last_loss} [ETA: {self.eta_hhmmss}]'
 
 			if lapsed:
@@ -645,18 +663,9 @@ class TrainingState():
 				if 'loss_gpt_total' in self.info:
 					self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
 
-					self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" })
-					self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" })
-					self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" })
-					"""
-					self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"])
-					self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"])
-					self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"])
-					"""
-					"""
-					self.losses['iteration'].append(self.it)
-					self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
-					"""
+					self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" })
+					self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" })
+					self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" })
 
 					should_return = True
 			elif line.find('Saving models and training states') >= 0:
diff --git a/src/webui.py b/src/webui.py
index 6958034..79153ef 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -380,7 +380,7 @@ def setup_gradio():
 					prompt = gr.Textbox(lines=1, label="Custom Emotion + Prompt (if selected)")
 					voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit
 					mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath" )
-					voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=64, value=1, step=1)
+					voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1)
 					with gr.Row():
 						refresh_voices = gr.Button(value="Refresh Voice List")
 						recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents")
@@ -538,12 +538,12 @@ def setup_gradio():
 						training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
 						training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
 
-						training_loss_graph = gr.LinePlot(label="Loss Rates",
-							x="iteration",
-							y="loss",
-							title="Loss Rates",
+						training_loss_graph = gr.LinePlot(label="Training Metrics",
+							x="step",
+							y="value",
+							title="Training Metrics",
 							color="type",
-							tooltip=['iteration', 'loss', 'type'],
+							tooltip=['step', 'value', 'type'],
 							width=600,
 							height=350
 						)