From 81eb58f0d6b2dbedbb0b479bf9bc1b9ba80b3f41 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Tue, 28 Feb 2023 06:18:18 +0000
Subject: [PATCH] show different losses, rewordings

---
 src/utils.py | 23 ++++++++++++++++++-----
 src/webui.py |  8 +++++---
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 3b081db..14aa533 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -482,10 +482,7 @@ class TrainingState():
 		self.eta = "?"
 		self.eta_hhmmss = "?"
 
-		self.losses = {
-			'iteration': [],
-			'loss_gpt_total': []
-		}
+		self.losses = []
 
 
 		self.load_losses()
@@ -522,8 +519,14 @@ class TrainingState():
 
 		for k in infos:
 			if 'loss_gpt_total' in infos[k]:
+				# self.losses.append([ int(k), infos[k]['loss_text_ce'], infos[k]['loss_mel_ce'], infos[k]['loss_gpt_total'] ])
+				self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_text_ce'], "type": "text_ce" })
+				self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_mel_ce'], "type": "mel_ce" })
+				self.losses.append({ "iteration": int(k), "loss": infos[k]['loss_gpt_total'], "type": "gpt_total" })
+				"""
 				self.losses['iteration'].append(int(k))
 				self.losses['loss_gpt_total'].append(infos[k]['loss_gpt_total'])
+				"""
 
 	def cleanup_old(self, keep=2):
 		if keep <= 0:
@@ -593,7 +596,7 @@ class TrainingState():
 					except Exception as e:
 						pass
 
-					message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses["iteration"][-1]}: {self.losses["loss_gpt_total"][-1]}] [ETA: {self.eta_hhmmss}]'
+					message = f'[{self.epoch}/{self.epochs}, {self.it}/{self.its}, {step}/{steps}] [{self.epoch_rate}, {self.it_rate}] [Loss at it {self.losses[-1]["iteration"]}: {self.losses[-1]["loss"]}] [ETA: {self.eta_hhmmss}]'
 
 			if lapsed:
 				self.epoch = self.epoch + 1
@@ -631,8 +634,18 @@ class TrainingState():
 				if 'loss_gpt_total' in self.info:
 					self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
 
+					self.losses.append({ "iteration": self.it, "loss": self.info['loss_text_ce'], "type": "text_ce" })
+					self.losses.append({ "iteration": self.it, "loss": self.info['loss_mel_ce'], "type": "mel_ce" })
+					self.losses.append({ "iteration": self.it, "loss": self.info['loss_gpt_total'], "type": "gpt_total" })
+					"""
+					self.losses.append([int(k), self.info['loss_text_ce'], "loss_text_ce"])
+					self.losses.append([int(k), self.info['loss_mel_ce'], "loss_mel_ce"])
+					self.losses.append([int(k), self.info['loss_gpt_total'], "loss_gpt_total"])
+					"""
+					"""
 					self.losses['iteration'].append(self.it)
 					self.losses['loss_gpt_total'].append(self.info['loss_gpt_total'])
+					"""
 
 					verbose = True
 			elif line.find('Saving models and training states') >= 0:
diff --git a/src/webui.py b/src/webui.py
index 47a36f3..b4c548f 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -508,12 +508,14 @@ def setup_gradio():
 						training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
 						verbose_training = gr.Checkbox(label="Verbose Console Output")
 						training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
-						training_keep_x_past_datasets = gr.Slider(label="Keep X Previous Datasets", minimum=0, maximum=8, value=0)
+						training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0)
 
 						training_loss_graph = gr.LinePlot(label="Loss Rates",
 							x="iteration",
-							y="loss_gpt_total",
+							y="loss",
 							title="Loss Rates",
+							color="type",
+							tooltip=['iteration', 'loss', 'type'],
 							width=600,
 							height=350
 						)
@@ -539,7 +541,7 @@ def setup_gradio():
 				with gr.Column():
 					exec_inputs = exec_inputs + [
 						gr.Number(label="Sample Batch Size", precision=0, value=args.sample_batch_size),
-						gr.Number(label="Concurrency Count", precision=0, value=args.concurrency_count),
+						gr.Number(label="Gradio Concurrency Count", precision=0, value=args.concurrency_count),
 						gr.Number(label="Output Sample Rate", precision=0, value=args.output_sample_rate),
 						gr.Slider(label="Output Volume", minimum=0, maximum=2, value=args.output_volume),
 					]