diff --git a/vall_e/config.py b/vall_e/config.py
index c64880a..e4c29cb 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -569,6 +569,7 @@ class DeepSpeed:
 
 	loss_scale_window: int = 1000
 	min_loss_scale: float = 32768.0
+	max_loss_scale: float = 1048576.0
 	loss_scale = 0.0
 
 	config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 7d63f56..7d9388e 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -282,6 +282,20 @@ class Engine():
 			elif 'lr' in param_group:
 				param_group['lr'] = lr
 
+	def get_loss_scale(self):
+		if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+			return 1
+
+		return self.loss_scaler.get_scale()
+
+	def set_loss_scale(self, value):
+		if not hasattr(self, "loss_scaler") or self.loss_scaler is None:
+			return
+		
+		"""
+		self.optimizer.loss_scale = value
+		"""
+
 	def get_global_grad_norm(self):
 		return self._global_grad_norm
 
@@ -457,6 +471,12 @@ class Engines(dict[str, Engine]):
 				continue
 			engine.set_lr(lr)
 
+	def set_loss_scale(self, lr):
+		for engine in self.values():
+			if not engine._training:
+				continue
+			engine.set_loss_scale(lr)
+
 	def _update(self):
 		for engine in self.values():
 			self._global_step = max(self._global_step, engine.global_step)
@@ -584,11 +604,11 @@ class Engines(dict[str, Engine]):
 			elapsed_time = time.time() - start_time
 			total_elapsed_time += elapsed_time
 			grad_norm = engine.get_global_grad_norm()
-			loss_scale = 1
-			if hasattr(engine.optimizer, "loss_scale") and engine.optimizer.loss_scale is not None:
-				loss_scale = engine.optimizer.loss_scale
-			elif hasattr(engine, "loss_scaler") and engine.loss_scaler is not None:
-				loss_scale = engine.loss_scaler.get_scale()
+			loss_scale = engine.get_loss_scale()
+
+			if cfg.trainer.deepspeed.max_loss_scale > 0 and loss_scale > cfg.trainer.deepspeed.max_loss_scale:
+				_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
+				engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
 			if grad_norm is not None:
 				grad_norm /= loss_scale
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 85ec63b..27e4276 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -119,6 +119,19 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			_logger.warning(str(e))
 
+	# cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR
+	def get_loss_scale(self):
+		if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+			return 1.0
+
+		return self.optimizer.cur_scale
+
+	def set_loss_scale(self, value):
+		if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
+			return
+		
+		self.optimizer.cur_scale = value
+
 	# we'll just have to live with the LoRA weights living within our main weights
 	# they're easy to extract anyways
 	def load_checkpoint(self, load_dir, **kwargs ):
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 09609e7..a563e29 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -233,7 +233,16 @@ def train(
 					engines.set_lr(rate)
 					_logger.info(f"Updating LR to: {rate}")
 				except Exception as e:
-					_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
+					_logger.warning(f"Failed to set LR to: {rate}, {str(e)}")
+
+			if "loss_scale" in command:
+				value = float(command.split(" ")[-1])
+				try:
+					engines.set_loss_scale(value)
+					_logger.info(f"Updating loss scale to: {value}")
+				except Exception as e:
+					raise e
+					_logger.warning(f"Failed to set loss scale to: {value}, {str(e)}")
 
 			if "export" in command:
 				train_dl.dataset.save_state_dict()