diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 3e39c61..e268752 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -370,16 +370,12 @@ class AudioEncoder(nn.Module):
 class AudioDecoder(nn.Module):
 	def __init__(
 		self,
-		levels,
 		d_model,
 		hidden_size,
 		vocab_size,
 	):
 		super().__init__()
-		hidden_size *= levels
-		vocab_size *= levels
-
 		self.vocab_size = vocab_size

 		self.up = nn.Linear( d_model, hidden_size )
 		self.down = nn.Linear( hidden_size, vocab_size )
@@ -715,8 +711,6 @@ class Base(nn.Module):
 		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
 		self.monolithic_audio_encoder = False # monolithic sounds bad
 		if self.version >= 7:
-			dec_dim = d_model * 4
-
 			if self.monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
 					n_tokens=n_audio_tokens + 1, # masked token
@@ -736,10 +730,9 @@ class Base(nn.Module):
 			)

 		self.audio_decoder = AudioDecoder(
-			self.n_resp_levels,
 			d_model,
-			dec_dim,
-			n_audio_tokens + 1,
+			d_model * 2,
+			(n_audio_tokens + 1) * self.n_resp_levels,
 		)

 		if attention_backend == "auto":
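For context, a minimal sketch of what `AudioDecoder` and its call site look like once this diff is applied: the `levels` argument is gone, and the caller now folds the level count into `hidden_size` and `vocab_size` itself. The forward pass and the surrounding `Base.__init__` arguments (`n_audio_tokens`, `self.n_resp_levels`) are not part of the hunks above, so they are assumed here rather than taken from the diff.

```python
import torch.nn as nn

class AudioDecoder(nn.Module):
	# After the diff: no `levels` parameter; the decoder no longer multiplies
	# hidden_size / vocab_size internally, the caller passes them pre-scaled.
	def __init__(self, d_model, hidden_size, vocab_size):
		super().__init__()
		self.vocab_size = vocab_size
		self.up = nn.Linear(d_model, hidden_size)
		self.down = nn.Linear(hidden_size, vocab_size)

# Call site in Base.__init__ after the diff (names taken from the hunks above):
# self.audio_decoder = AudioDecoder(
# 	d_model,
# 	d_model * 2,                                # hidden_size, previously dec_dim = d_model * 4
# 	(n_audio_tokens + 1) * self.n_resp_levels,  # one vocab slice per response level
# )
```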