diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c4ec0c5..aa1a5ba 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -547,7 +547,6 @@ class Base(nn.Module): self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True self.stop_token = self.n_audio_tokens # id 1024 - self.mask_token = self.n_audio_tokens + 1 # id 1024 self.causal = "ar" in self.capabilities or "len" in self.capabilities self.version = self.config.version if self.config is not None else 5 self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0) @@ -716,7 +715,7 @@ class Base(nn.Module): if self.version >= 7: if monolithic_audio_encoder: self.audio_emb = AudioEncoder( - n_tokens=n_audio_tokens + 2, # stop + masked token + n_tokens=n_audio_tokens + 1, # masked token n_levels=self.n_resp_levels, token_dim=d_model, ) @@ -727,7 +726,7 @@ class Base(nn.Module): token_dim=d_model, ) self.resps_emb = AudioEncoder( - n_tokens=n_audio_tokens + 2, # stop + masked token + n_tokens=n_audio_tokens + 1, # masked token n_levels=self.n_resp_levels, token_dim=d_model, ) @@ -1311,9 +1310,9 @@ class Base(nn.Module): elif name == "resp": if self.version >= 7: if self.audio_emb is not None: - embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) + embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token ) else: - embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) + embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token ) # if training NAR-len RVQ level 0 elif dropout_mask is not None: embedding = self.resps_emb(