added backwards compat flag
This commit is contained in:
parent
ab5134f385
commit
c47fc3274e
|
@ -163,6 +163,7 @@ class Model:
|
|||
arch_type: str = "transformer"
|
||||
training: bool = True
|
||||
interleave: bool = False
|
||||
use_multiembedding: bool = True # nasty bandaid I got myself into
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
|
|
|
@ -156,10 +156,6 @@ class Base(nn.Module):
|
|||
def n_embeddings(self) -> int:
|
||||
return self.n_resp_levels if self.monolithic else 1
|
||||
|
||||
@property
|
||||
def use_old_embeddings(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def stop_token(self):
|
||||
if not self.causal:
|
||||
|
@ -187,11 +183,15 @@ class Base(nn.Module):
|
|||
p_dropout: float = 0.1,
|
||||
|
||||
config = None,
|
||||
use_multiembedding = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
|
||||
|
||||
if self.config is not None and hasattr(self.config, "use_multiembedding"):
|
||||
use_multiembedding = self.config.use_multiembedding
|
||||
|
||||
self.n_tokens = n_tokens
|
||||
self.d_model = d_model
|
||||
self.n_heads = n_heads
|
||||
|
@ -205,14 +205,14 @@ class Base(nn.Module):
|
|||
|
||||
# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
|
||||
# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
|
||||
if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
|
||||
if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
|
||||
self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
else:
|
||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
|
||||
# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
|
||||
# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
|
||||
if self.n_embeddings > 1 or not self.use_old_embeddings:
|
||||
if self.n_embeddings > 1 or not use_multiembedding:
|
||||
self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
|
||||
else:
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
|
|
Loading…
Reference in New Issue
Block a user