diff --git a/src/utils.py b/src/utils.py
index 241cb7f..f8d008f 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -627,7 +627,10 @@ class TrainingState():
 		self.nan_detected = False
 
 		self.last_info_check_at = 0
-		self.statistics = []
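+		# Keep loss and learning-rate samples in separate series so each can be
+		# plotted on its own graph with an independent scale.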
+		self.statistics = {
+			'loss': [],
+			'lr': [],
+		}
 		self.losses = []
 		self.metrics = {
 			'step': "",
@@ -637,7 +640,7 @@ class TrainingState():
 
 		self.loss_milestones = [ 1.0, 0.15, 0.05 ]
 
-		self.load_losses()
+		self.load_statistics()
 		if keep_x_past_checkpoints > 0:
 			self.cleanup_old(keep=keep_x_past_checkpoints)
 		if start:
@@ -649,7 +652,7 @@ class TrainingState():
 		print("Spawning process: ", " ".join(self.cmd))
 		self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
-	def load_losses(self, update=False):
+	def load_statistics(self, update=False):
 		if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
 			return
 		try:
@@ -658,69 +661,40 @@ class TrainingState():
 		except Exception as e:
 			use_tensorboard = False
 
+		# The plain-text .log fallback is removed in this patch, so bail out early
+		# when the tensorboard import failed instead of hitting a NameError on
+		# event_accumulator below.
+		if not use_tensorboard:
+			return
+
-		keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce']
+		keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
-		infos = {}
 		highest_step = self.last_info_check_at
 
 		if not update:
-			self.statistics = []
+			self.statistics['loss'] = []
+			self.statistics['lr'] = []
 
-		if use_tensorboard:
-			logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
-			if update:
-				logs = [logs[-1]]
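+		# Gather every TensorBoard event file for a full reload; incremental updates
+		# only need to re-read the newest file.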
+		logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
+		if update:
+			logs = [logs[-1]]
 
-			for log in logs:
-					ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
-					ea.Reload()
+		for log in logs:
+				ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
+				ea.Reload()
 
-					for key in keys:
-						try:
-							scalar = ea.Scalars(key)
-							for s in scalar:
-								if update and s.step <= self.last_info_check_at:
-									continue
-								highest_step = max( highest_step, s.step )
-								self.statistics.append( { "step": s.step, "value": s.value, "type": key } )
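+				# Query the accumulator for the scalar tags it actually holds so missing
+				# keys can be skipped up front instead of relying on exceptions.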
+				scalars = ea.Tags()['scalars']
 
-								if key == 'loss_gpt_total':
-									self.losses.append( { "step": s.step, "value": s.value, "type": key } )
-						except Exception as e:
-							pass
+				for key in keys:
+					if key not in scalars:
+						continue
 
-		else:
-			logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
-			if update:
-				logs = [logs[-1]]
-
-			for log in logs:
-				with open(log, 'r', encoding="utf-8") as f:
-					lines = f.readlines()
-					for line in lines:
-						if line.find('INFO: [epoch:') >= 0:
-							# easily rip out our stats...
-							match = re.findall(r'\b([a-z_0-9]+?)\b: +?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', line)
-							if not match or len(match) == 0:
+					try:
+						scalar = ea.Scalars(key)
+						for s in scalar:
+							if update and s.step <= self.last_info_check_at:
 								continue
-
-							info = {}
-							for k, v in match:
-								info[k] = float(v.replace(",", ""))
-
-							if 'iter' in info:
-								it = info['iter']
-								infos[it] = info
-
-			for k in infos:
-				if 'loss_gpt_total' in infos[k]:
-					for key in keys:
-						if update and int(k) <= self.last_info_check_at:
-							continue
-						highest_step = max( highest_step, s.step )
-						self.statistics.append({ "step": int(k), "value": infos[k][key], "type": key })
-
-						if key == "loss_gpt_total":
-							self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
+							highest_step = max( highest_step, s.step )
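+							# Route the learning-rate scalar into its own series; every other key is a loss.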
+							target = 'lr' if key == "learning_rate_gpt_0" else 'loss'
+							self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } )
+							if key == 'loss_gpt_total':
+								self.losses.append( { "step": s.step, "value": s.value, "type": key } )
+					except Exception as e:
+						pass
 
 		self.last_info_check_at = highest_step
 
@@ -784,7 +758,7 @@ class TrainingState():
 					for k, v in match:
 						self.info[k] = float(v.replace(",", ""))
 
-				self.load_losses(update=True)
+				self.load_statistics(update=True)
 				should_return = True
 
 				if 'epoch' in self.info:
@@ -1003,20 +977,26 @@ def run_training(config_path, verbose=False, keep_x_past_checkpoints=0, progress
 
 def update_training_dataplot(config_path=None):
 	global training_state
-	update = None
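+	# Build two independent plot updates: one for the loss curves, one for the
+	# learning-rate curve.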
+	losses = None
+	lrs = None
 
 	if not training_state:
 		if config_path:
 			training_state = TrainingState(config_path=config_path, start=False)
-			if training_state.statistics:
-				update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
+			if len(training_state.statistics['loss']) > 0:
+				losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
+			if len(training_state.statistics['lr']) > 0:
+				lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Learning Rate", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
 			del training_state
 			training_state = None
-	elif training_state.statistics:
-		training_state.load_losses()
-		update = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=600, height=350,)
+	else:
+		training_state.load_statistics()
+		if len(training_state.statistics['loss']) > 0:
+			losses = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['loss']), x_lim=[0,training_state.its], x="step", y="value", title="Training Metrics", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
+		if len(training_state.statistics['lr']) > 0:
+			lrs = gr.LinePlot.update(value=pd.DataFrame(training_state.statistics['lr']), x_lim=[0,training_state.its], x="step", y="value", title="Learning Rate", color="type", tooltip=['step', 'value', 'type'], width=500, height=350,)
 
-	return update
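+	# gradio assigns the returned tuple positionally to the two LinePlot outputs.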
+	return (losses, lrs)
 
 def reconnect_training(verbose=False, progress=gr.Progress(track_tqdm=True)):
 	global training_state
@@ -1363,9 +1343,11 @@ def save_training_settings( **kwargs ):
 	settings['iterations'] = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
 	messages.append(f"For {settings['epochs']} epochs with {lines} lines, iterating for {settings['iterations']} steps")
 
-	settings['print_rate'] = int(settings['print_rate'] * settings['iterations'] / settings['epochs'])
-	settings['save_rate'] = int(settings['save_rate'] * settings['iterations'] / settings['epochs'])
-	settings['validation_rate'] = int(settings['validation_rate'] * settings['iterations'] / settings['epochs'])
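+	# These rates are specified per epoch; convert once to absolute iteration counts.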
+	iterations_per_epoch = int(settings['iterations'] / settings['epochs'])
+
+	settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
+	settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
+	settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
 
 	settings['validation_batch_size'] = int(settings['batch_size'] / settings['gradient_accumulation_size'])
 
@@ -1407,16 +1389,31 @@ def save_training_settings( **kwargs ):
 		elif isinstance(settings['learning_rate_schedule'],str):
 			settings['learning_rate_schedule'] = json.loads(settings['learning_rate_schedule'])
 
-		settings['learning_rate_schedule'] = schedule_learning_rate( settings['iterations'] / settings['epochs'], settings['learning_rate_schedule'] )
+		settings['learning_rate_schedule'] = schedule_learning_rate( iterations_per_epoch, settings['learning_rate_schedule'] )
 
 		learning_rate_schema.append(f"  gen_lr_steps: {settings['learning_rate_schedule']}")
 		learning_rate_schema.append(f"  lr_gamma: 0.5")
 	elif settings['learning_rate_scheme'] == "CosineAnnealingLR_Restart":
-		learning_rate_schema.append(f"  T_period: [120000, 120000, 120000]")
-		learning_rate_schema.append(f"  warmup: 10000")
-		learning_rate_schema.append(f"  eta_min: .01")
-		learning_rate_schema.append(f"  restarts: [140000, 280000]")
-		learning_rate_schema.append(f"  restart_weights: [.5, .25]")
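+		# Derive the cosine-annealing schedule from the actual run length instead of
+		# the old hard-coded 120k-step constants; assume a restart every other epoch.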
+		epochs = settings['epochs']
+		restarts = int(epochs / 2)
+
+		if 'learning_rate_period' not in settings:
+			settings['learning_rate_period'] = [ iterations_per_epoch for x in range(epochs) ]
+		if 'learning_rate_warmup' not in settings:
+			settings['learning_rate_warmup'] = 0
+		if 'learning_rate_min' not in settings:
+			settings['learning_rate_min'] = 1e-07
+		if 'learning_rate_restarts' not in settings:
+			settings['learning_rate_restarts'] = [ iterations_per_epoch * (x+1) * 2 for x in range(restarts) ] # [52, 104, 156, 208]
+		if 'learning_rate_restart_weights' not in settings:
+			settings['learning_rate_restart_weights'] = [ ( restarts - x - 1 ) / restarts for x in range(restarts) ] # [.75, .5, .25, .125]
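+			# The formula above leaves the final weight at zero, so taper it to half
+			# of the previous restart's weight instead.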
+			if len(settings['learning_rate_restart_weights']) > 1: # short runs may yield fewer than two restarts
+				settings['learning_rate_restart_weights'][-1] = settings['learning_rate_restart_weights'][-2] * 0.5
+
+		learning_rate_schema.append(f"  T_period: {settings['learning_rate_period']}")
+		learning_rate_schema.append(f"  warmup: !!float {settings['learning_rate_warmup']}")
+		learning_rate_schema.append(f"  eta_min: !!float {settings['learning_rate_min']}")
+		learning_rate_schema.append(f"  restarts: {settings['learning_rate_restarts']}")
+		learning_rate_schema.append(f"  restart_weights: {settings['learning_rate_restart_weights']}")
 	settings['learning_rate_scheme'] = "\n".join(learning_rate_schema)
 
 	"""
diff --git a/src/webui.py b/src/webui.py
index 6691759..8b5ddae 100755
--- a/src/webui.py
+++ b/src/webui.py
@@ -430,21 +430,7 @@ def setup_gradio():
 				with gr.Row():
 					with gr.Column():
 						training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
-						with gr.Row():
-							refresh_configs = gr.Button(value="Refresh Configurations")
-						
-						training_loss_graph = gr.LinePlot(label="Training Metrics",
-							x="step",
-							y="value",
-							title="Training Metrics",
-							color="type",
-							tooltip=['step', 'value', 'type'],
-							width=600,
-							height=350,
-						)
-						view_losses = gr.Button(value="View Losses")
-
-					with gr.Column():
+						refresh_configs = gr.Button(value="Refresh Configurations")
 						training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
 						verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
 						
@@ -453,6 +439,27 @@ def setup_gradio():
 							start_training_button = gr.Button(value="Train")
 							stop_training_button = gr.Button(value="Stop")
 							reconnect_training_button = gr.Button(value="Reconnect")
+						
+					with gr.Column():
+						training_loss_graph = gr.LinePlot(label="Training Metrics",
+							x="step",
+							y="value",
+							title="Training Metrics",
+							color="type",
+							tooltip=['step', 'value', 'type'],
+							width=500,
+							height=350,
+						)
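+						# Dedicated plot for the learning-rate curve fed by the new 'lr' series.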
+						training_lr_graph = gr.LinePlot(label="Learning Rate",
+							x="step",
+							y="value",
+							title="Learning Rate",
+							color="type",
+							tooltip=['step', 'value', 'type'],
+							width=500,
+							height=350,
+						)
+						view_losses = gr.Button(value="View Losses")
 		with gr.Tab("Settings"):
 			with gr.Row():
 				exec_inputs = []
@@ -650,6 +657,7 @@ def setup_gradio():
 			inputs=None,
 			outputs=[
 				training_loss_graph,
+				training_lr_graph,
 			],
 			show_progress=False,
 		)
@@ -661,6 +669,7 @@ def setup_gradio():
 			],
 			outputs=[
 				training_loss_graph,
+				training_lr_graph,
 			],
 		)