undo separating mask and stop token, this causes bigly problems...
This commit is contained in:
parent
3019c88799
commit
504b1ae832
|
@ -547,7 +547,6 @@ class Base(nn.Module):
|
||||||
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||||
|
|
||||||
self.stop_token = self.n_audio_tokens # id 1024
|
self.stop_token = self.n_audio_tokens # id 1024
|
||||||
self.mask_token = self.n_audio_tokens + 1 # id 1024
|
|
||||||
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
||||||
self.version = self.config.version if self.config is not None else 5
|
self.version = self.config.version if self.config is not None else 5
|
||||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
||||||
|
@ -716,7 +715,7 @@ class Base(nn.Module):
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
if monolithic_audio_encoder:
|
if monolithic_audio_encoder:
|
||||||
self.audio_emb = AudioEncoder(
|
self.audio_emb = AudioEncoder(
|
||||||
n_tokens=n_audio_tokens + 2, # stop + masked token
|
n_tokens=n_audio_tokens + 1, # masked token
|
||||||
n_levels=self.n_resp_levels,
|
n_levels=self.n_resp_levels,
|
||||||
token_dim=d_model,
|
token_dim=d_model,
|
||||||
)
|
)
|
||||||
|
@ -727,7 +726,7 @@ class Base(nn.Module):
|
||||||
token_dim=d_model,
|
token_dim=d_model,
|
||||||
)
|
)
|
||||||
self.resps_emb = AudioEncoder(
|
self.resps_emb = AudioEncoder(
|
||||||
n_tokens=n_audio_tokens + 2, # stop + masked token
|
n_tokens=n_audio_tokens + 1, # masked token
|
||||||
n_levels=self.n_resp_levels,
|
n_levels=self.n_resp_levels,
|
||||||
token_dim=d_model,
|
token_dim=d_model,
|
||||||
)
|
)
|
||||||
|
@ -1311,9 +1310,9 @@ class Base(nn.Module):
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
if self.audio_emb is not None:
|
if self.audio_emb is not None:
|
||||||
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
|
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
|
||||||
else:
|
else:
|
||||||
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
|
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
|
||||||
# if training NAR-len RVQ level 0
|
# if training NAR-len RVQ level 0
|
||||||
elif dropout_mask is not None:
|
elif dropout_mask is not None:
|
||||||
embedding = self.resps_emb(
|
embedding = self.resps_emb(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user