added backwards compat flag

mrq 2023-09-07 17:12:17 -05:00
parent ab5134f385
commit c47fc3274e
2 changed files with 7 additions and 6 deletions


@@ -163,6 +163,7 @@ class Model:
     arch_type: str = "transformer"
     training: bool = True
     interleave: bool = False
+    use_multiembedding: bool = True # nasty bandaid I got myself into
     @property
     def full_name(self):

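The config-side half of the change is just a defaulted dataclass field. Below is a minimal sketch of why defaulting to True preserves backwards compatibility; the `name` field and `model_from_dict` helper are hypothetical stand-ins, not the repo's actual loader. Configs written before this commit never mention the key and so keep the old MultiEmbedding behaviour, while new configs can opt out explicitly.

from dataclasses import dataclass, fields

@dataclass
class Model:
    name: str = "ar+nar"             # hypothetical field, for illustration only
    use_multiembedding: bool = True  # the new backwards-compat flag

def model_from_dict(d: dict) -> Model:
    # lenient load: keep only keys the dataclass knows about
    known = {f.name for f in fields(Model)}
    return Model(**{k: v for k, v in d.items() if k in known})

old_cfg = model_from_dict({"name": "ar+nar"})                               # pre-commit config, no flag
new_cfg = model_from_dict({"name": "ar+nar", "use_multiembedding": False})  # opts into the new embeddings

assert old_cfg.use_multiembedding is True
assert new_cfg.use_multiembedding is False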

@@ -156,10 +156,6 @@ class Base(nn.Module):
     def n_embeddings(self) -> int:
         return self.n_resp_levels if self.monolithic else 1
-    @property
-    def use_old_embeddings(self) -> bool:
-        return True
     @property
     def stop_token(self):
         if not self.causal:
@@ -187,11 +183,15 @@ class Base(nn.Module):
         p_dropout: float = 0.1,
         config = None,
+        use_multiembedding = False,
     ):
         super().__init__()
         self.config = config
         self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
+        if self.config is not None and hasattr(self.config, "use_multiembedding"):
+            use_multiembedding = self.config.use_multiembedding
         self.n_tokens = n_tokens
         self.d_model = d_model
         self.n_heads = n_heads
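The constructor defaults the new keyword to False, but a config that defines `use_multiembedding` overrides it, which is how old configs (whose dataclass default is True) land back on the MultiEmbedding path. A minimal sketch of that hasattr-based override, with SimpleNamespace standing in for the real config object:

from types import SimpleNamespace

def resolve_use_multiembedding(config=None, use_multiembedding: bool = False) -> bool:
    # mirrors the constructor logic above: a config attribute, when present,
    # overrides the keyword default
    if config is not None and hasattr(config, "use_multiembedding"):
        use_multiembedding = config.use_multiembedding
    return use_multiembedding

assert resolve_use_multiembedding(None) is False                                    # no config: new default
assert resolve_use_multiembedding(SimpleNamespace()) is False                        # config without the flag
assert resolve_use_multiembedding(SimpleNamespace(use_multiembedding=True)) is True  # old-style config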
@@ -205,14 +205,14 @@ class Base(nn.Module):
         # use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
         # n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-        if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
+        if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
             self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
         else:
             self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
         # use dedicated embeddings for each RVQ-bin level in the output response / target if requested
         # n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-        if self.n_embeddings > 1 or not self.use_old_embeddings:
+        if self.n_embeddings > 1 or not use_multiembedding:
             self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
         else:
             self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
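Taken together, the flag only changes anything when the earlier conditions would not already have picked the dedicated per-level embeddings. A minimal sketch of the resulting selection, with class-name strings standing in for the real PromEmbedding / RespEmbedding / MultiEmbedding modules and illustrative level counts that are assumptions, not values from the repo:

def pick_embeddings(n_embeddings: int, n_prom_levels: int, use_multiembedding: bool):
    # same branching as the constructor above, reduced to class-name strings
    proms = "PromEmbedding" if (n_embeddings == n_prom_levels or not use_multiembedding) else "MultiEmbedding"
    resps = "RespEmbedding" if (n_embeddings > 1 or not use_multiembedding) else "MultiEmbedding"
    return proms, resps

# a split AR/NAR (n_embeddings == 1) with the flag on keeps the old MultiEmbedding everywhere
assert pick_embeddings(1, 8, use_multiembedding=True) == ("MultiEmbedding", "MultiEmbedding")
# turning the flag off forces the dedicated per-level embeddings
assert pick_embeddings(1, 8, use_multiembedding=False) == ("PromEmbedding", "RespEmbedding")
# a monolithic model (n_embeddings == n_resp_levels == n_prom_levels here) already used them regardless
assert pick_embeddings(8, 8, use_multiembedding=True) == ("PromEmbedding", "RespEmbedding")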