undo separating mask and stop token, this causes bigly problems...

This commit is contained in:
mrq 2025-02-23 11:49:49 -06:00
parent 3019c88799
commit 504b1ae832

View File

@ -547,7 +547,6 @@ class Base(nn.Module):
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
self.stop_token = self.n_audio_tokens # id 1024
self.mask_token = self.n_audio_tokens + 1 # id 1024
self.causal = "ar" in self.capabilities or "len" in self.capabilities
self.version = self.config.version if self.config is not None else 5
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
@ -716,7 +715,7 @@ class Base(nn.Module):
if self.version >= 7:
if monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_tokens=n_audio_tokens + 1, # masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
@ -727,7 +726,7 @@ class Base(nn.Module):
token_dim=d_model,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_tokens=n_audio_tokens + 1, # masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
@ -1311,9 +1310,9 @@ class Base(nn.Module):
elif name == "resp":
if self.version >= 7:
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
else:
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
# if training NAR-len RVQ level 0
elif dropout_mask is not None:
embedding = self.resps_emb(