From 504b1ae832ef5c263884c3c238cf494c5d5cac6d Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 23 Feb 2025 11:49:49 -0600 Subject: [PATCH] undo separating mask and stop token, this causes bigly problems... --- vall_e/models/base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index c4ec0c5..aa1a5ba 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -547,7 +547,6 @@ class Base(nn.Module): self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True self.stop_token = self.n_audio_tokens # id 1024 - self.mask_token = self.n_audio_tokens + 1 # id 1024 self.causal = "ar" in self.capabilities or "len" in self.capabilities self.version = self.config.version if self.config is not None else 5 self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0) @@ -716,7 +715,7 @@ class Base(nn.Module): if self.version >= 7: if monolithic_audio_encoder: self.audio_emb = AudioEncoder( - n_tokens=n_audio_tokens + 2, # stop + masked token + n_tokens=n_audio_tokens + 1, # masked token n_levels=self.n_resp_levels, token_dim=d_model, ) @@ -727,7 +726,7 @@ class Base(nn.Module): token_dim=d_model, ) self.resps_emb = AudioEncoder( - n_tokens=n_audio_tokens + 2, # stop + masked token + n_tokens=n_audio_tokens + 1, # masked token n_levels=self.n_resp_levels, token_dim=d_model, ) @@ -1311,9 +1310,9 @@ class Base(nn.Module): elif name == "resp": if self.version >= 7: if self.audio_emb is not None: - embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) + embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token ) else: - embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token ) + embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token ) # if training NAR-len RVQ level 0 elif dropout_mask is not None: embedding = self.resps_emb(