diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 3e39c61..e268752 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -370,16 +370,12 @@ class AudioEncoder(nn.Module):
 class AudioDecoder(nn.Module):
 	def __init__(
 		self,
-		levels,
 		d_model,
 		hidden_size,
 		vocab_size,
 	):
 		super().__init__()
 
-		hidden_size *= levels
-		vocab_size *= levels
-
 		self.vocab_size = vocab_size
 		self.up = nn.Linear( d_model, hidden_size )
 		self.down = nn.Linear( hidden_size, vocab_size )
@@ -715,8 +711,6 @@ class Base(nn.Module):
 		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
 		self.monolithic_audio_encoder = False # monolithic sounds bad
 		if self.version >= 7:
-			dec_dim = d_model * 4
-
 			if self.monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
 					n_tokens=n_audio_tokens + 1, # masked token
@@ -736,10 +730,9 @@ class Base(nn.Module):
 				)
 
 			self.audio_decoder = AudioDecoder(
-				self.n_resp_levels,
 				d_model,
-				dec_dim,
-				n_audio_tokens + 1,
+				d_model * 2,
+				(n_audio_tokens + 1) * self.n_resp_levels,
 			)
 
 		if attention_backend == "auto":