undo separating mask and stop token, this causes bigly problems...
This commit is contained in:
parent
3019c88799
commit
504b1ae832
|
@ -547,7 +547,6 @@ class Base(nn.Module):
|
|||
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||
|
||||
self.stop_token = self.n_audio_tokens # id 1024
|
||||
self.mask_token = self.n_audio_tokens + 1 # id 1024
|
||||
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
||||
self.version = self.config.version if self.config is not None else 5
|
||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
||||
|
@ -716,7 +715,7 @@ class Base(nn.Module):
|
|||
if self.version >= 7:
|
||||
if monolithic_audio_encoder:
|
||||
self.audio_emb = AudioEncoder(
|
||||
n_tokens=n_audio_tokens + 2, # stop + masked token
|
||||
n_tokens=n_audio_tokens + 1, # masked token
|
||||
n_levels=self.n_resp_levels,
|
||||
token_dim=d_model,
|
||||
)
|
||||
|
@ -727,7 +726,7 @@ class Base(nn.Module):
|
|||
token_dim=d_model,
|
||||
)
|
||||
self.resps_emb = AudioEncoder(
|
||||
n_tokens=n_audio_tokens + 2, # stop + masked token
|
||||
n_tokens=n_audio_tokens + 1, # masked token
|
||||
n_levels=self.n_resp_levels,
|
||||
token_dim=d_model,
|
||||
)
|
||||
|
@ -1311,9 +1310,9 @@ class Base(nn.Module):
|
|||
elif name == "resp":
|
||||
if self.version >= 7:
|
||||
if self.audio_emb is not None:
|
||||
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
|
||||
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
|
||||
else:
|
||||
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
|
||||
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
|
||||
# if training NAR-len RVQ level 0
|
||||
elif dropout_mask is not None:
|
||||
embedding = self.resps_emb(
|
||||
|
|
Loading…
Reference in New Issue
Block a user