added backwards compat flag
parent ab5134f385
commit c47fc3274e
@@ -163,6 +163,7 @@ class Model:
 	arch_type: str = "transformer"
 	training: bool = True
 	interleave: bool = False
+	use_multiembedding: bool = True # nasty bandaid I got myself into

 	@property
 	def full_name(self):
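For context, a minimal sketch of what the new field amounts to on the model description, assuming Model is a plain dataclass and omitting its other fields. The default of True is what keeps existing configs and checkpoints on the old MultiEmbedding behaviour; new setups can opt out explicitly.

# Sketch only; assumes Model is a dataclass and shows just the fields touched here.
from dataclasses import dataclass

@dataclass
class Model:
	arch_type: str = "transformer"
	training: bool = True
	interleave: bool = False
	use_multiembedding: bool = True  # True = keep the old MultiEmbedding layout

model = Model(use_multiembedding=False)  # opt out of the backwards-compat path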
@@ -156,10 +156,6 @@ class Base(nn.Module):
 	def n_embeddings(self) -> int:
 		return self.n_resp_levels if self.monolithic else 1

-	@property
-	def use_old_embeddings(self) -> bool:
-		return True
-
 	@property
 	def stop_token(self):
		if not self.causal:
@@ -187,11 +183,15 @@ class Base(nn.Module):
 		p_dropout: float = 0.1,

 		config = None,
+		use_multiembedding = False,
 	):
 		super().__init__()
 		self.config = config
 		self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True

+		if self.config is not None and hasattr(self.config, "use_multiembedding"):
+			use_multiembedding = self.config.use_multiembedding
+
 		self.n_tokens = n_tokens
 		self.d_model = d_model
 		self.n_heads = n_heads
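To make the override order explicit: the constructor now takes use_multiembedding (defaulting to False), but a config object that carries the attribute wins. A standalone sketch of just that resolution, with a made-up helper name and SimpleNamespace standing in for the real config object; none of this is part of the commit itself.

# resolve_use_multiembedding is a hypothetical name for illustration only.
from types import SimpleNamespace

def resolve_use_multiembedding(config, use_multiembedding=False):
	# a config attribute, when present, overrides the constructor argument
	if config is not None and hasattr(config, "use_multiembedding"):
		use_multiembedding = config.use_multiembedding
	return use_multiembedding

assert resolve_use_multiembedding(None) is False                                     # no config: constructor default
assert resolve_use_multiembedding(SimpleNamespace()) is False                        # config without the field: constructor default
assert resolve_use_multiembedding(SimpleNamespace(use_multiembedding=True)) is True  # old-style config: MultiEmbedding path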
@@ -205,14 +205,14 @@ class Base(nn.Module):

 		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
 		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-		if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
+		if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
 			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 		else:
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)

 		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
 		# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-		if self.n_embeddings > 1 or not self.use_old_embeddings:
+		if self.n_embeddings > 1 or not use_multiembedding:
 			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
 		else:
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
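And the net effect on embedding selection, as an illustrative sketch: the stub function below mirrors the two conditions in the hunk above with plain strings standing in for the real modules, and the level counts are arbitrary. The point is only that use_multiembedding=True (the backwards-compat default on existing configs) keeps the old MultiEmbedding path whenever the level counts don't already force the dedicated embeddings.

# Illustrative only: strings stand in for PromEmbedding/RespEmbedding/MultiEmbedding,
# and the numbers are made up.
def pick_embeddings(n_embeddings, n_prom_levels, use_multiembedding):
	proms = "PromEmbedding" if (n_embeddings == n_prom_levels or not use_multiembedding) else "MultiEmbedding"
	resps = "RespEmbedding" if (n_embeddings > 1 or not use_multiembedding) else "MultiEmbedding"
	return proms, resps

# Split AR/NAR (n_embeddings == 1) with the flag left at its default keeps the old behaviour:
assert pick_embeddings(1, 8, use_multiembedding=True) == ("MultiEmbedding", "MultiEmbedding")
# Opting out always selects the dedicated per-level embeddings:
assert pick_embeddings(1, 8, use_multiembedding=False) == ("PromEmbedding", "RespEmbedding")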