diff --git a/data/config.yaml b/data/config.yaml
index ad76bfb..cebc926 100644
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -2,24 +2,31 @@
 sample_rate: 24_000 # 44_000 for dac
 audio_backend: "vocos" # or dac
 
 models:
-- name: "ar+nar"
-  size: "full"
-  resp_levels: 8
-  prom_levels: 8
-  tasks: 8
-  langs: 2
-  tones: 1
-  arch_type: llama
-  training: True
-  version: 5
-  attention: flash_attention_2
-  dropout: 0.1
-  experimental: False
+- name: "ar+nar" # vanity name
+  size: "full" # model dimensionality
+  resp_levels: 8 # RVQ levels this model targets
+  prom_levels: 8 # should always match the above
+  tasks: 8 # tasks this model can attend to, only tts is supported at the moment
+  langs: 2 # languages this model supports, semi-unused at the moment
+  tones: 1 # tones this model supports, currently unused
+  arch_type: llama # underlying LLM arch to use, currently focusing on llama
+  training: True # signals this model is to be trained
+  version: 5 # helps keep backwards compatibility for when I add new things to the model
+  attention: auto # attention mechanism to use, "auto" for safety
+  dropout: 0.1 # percentage of the model to disable during training
+  # factors for split loss values, remove to have a unified loss calculation
   loss_factors:
-    text: 0.1
-    prom: 0.0
-    resp: 1.0
+    text: 0.1 # text phoneme portion of the sequence
+    prom: 0.0 # input prompt portion of the sequence
+    resp: 1.0 # output audio portion of the sequence
+
+  # experimental settings
+  experimental:
+    hf: False # uses vall_e.models.experimental, a wrapper around a HF model that could technically be used for non-pytorch backends later
+    interleave: False # interleaves RVQ levels, only works with the above for now
+    audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
+    audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
 
 hyperparameters:
   autotune: False
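Note on the split `loss_factors` block above: each factor scales the loss over its slice of the combined sequence (text phonemes, then input prompt, then output audio). Below is a minimal sketch of that weighting, assuming per-portion index spans into one flattened sequence; `split_loss` and the span layout are illustrative assumptions, not the repo's actual code.

```python
import torch
import torch.nn.functional as F

def split_loss(logits, targets, spans, factors={"text": 0.1, "prom": 0.0, "resp": 1.0}):
	"""
	logits:  [seq_len, vocab] model outputs for one sequence
	targets: [seq_len] token targets for the same sequence
	spans:   portion name -> (start, end) indices into the sequence
	"""
	losses = {}
	for name, (start, end) in spans.items():
		factor = factors.get(name, 0.0)
		if factor == 0.0:
			continue  # prom is weighted 0.0 above, so the prompt span contributes nothing
		losses[name] = factor * F.cross_entropy(logits[start:end], targets[start:end])
	return sum(losses.values()), losses

# hypothetical spans for one training sequence laid out as [text | prom | resp]
logits  = torch.randn(400, 1025)          # [seq_len, vocab]
targets = torch.randint(0, 1025, (400,))  # [seq_len]
loss, parts = split_loss(logits, targets, {"text": (0, 32), "prom": (32, 160), "resp": (160, 400)})
```

Removing the block would fall back to a single unweighted cross-entropy over the whole sequence, per the comment in the config.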
diff --git a/vall_e/config.py b/vall_e/config.py
index 256e962..fc3b144 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -201,8 +201,10 @@ class ModelExperimentalSettings:
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
 	split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
-	audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
+	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
+	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
+	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
 
 # I really need to clean this up
 @dataclass()
@@ -225,9 +227,6 @@ class Model:
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 
-	p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
-	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
-
 	experimental: dict | ModelExperimentalSettings | None = None # experimental settings
 
 	def get(self, name=None):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index b8c8d8c..04a6815 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -29,26 +29,10 @@ class AR_NAR(Base):
 			return self.config.capabilities
 		return cfg.model.capabilities
 
-	@property
-	def quant_level_range(self) -> list[int]:
-		if hasattr(self, "config") and self.config.rvq_level_range:
-			return self.config.rvq_level_range
-		return [ 0 if self.causal else 1, self.n_resp_levels ]
-
 	@property
 	def causal(self):
 		return "ar" in self.capabilities
 
-	@property
-	def norm_type(self):
-		return "ln" # if self.n_resp_levels == 1 else "adaln"
-
-	@property
-	def arch_type(self) -> str:
-		if hasattr(self, "config") and self.config:
-			return self.config.arch_type
-		return cfg.model.arch_type
-
 	@property
 	def n_prom_levels(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -72,12 +56,6 @@ class AR_NAR(Base):
 		if hasattr(self, "config") and self.config:
 			return self.config.tasks
 		return cfg.model.tasks
-
-	@property
-	def p_rvq_levels(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.p_rvq_levels
-		return cfg.model.p_rvq_levels
 
 	@property
 	def n_langs(self) -> int:
@@ -159,16 +137,18 @@ class AR_NAR(Base):
 		def sample_task():
 			return "tts"
 
+		p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
+
 		# generate task list to train against
 		task_list = [ sample_task() for _ in range(batch_size) ]
 
 		# determines which RVQ level to target per batch
-		quant_level_range = self.quant_level_range
+		quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
 
-		if self.p_rvq_levels == "equal":
+		if p_rvq_levels == "equal":
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
-		else: # if self.p_rvq_levels == "auto":
+		else: # if p_rvq_levels == "auto":
 			# makes higher levels less likely
 			def generate( lo=0, hi=8 ):
 				index = lo
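The `generate()` helper is truncated at the end of this hunk, so the exact "auto" weighting isn't visible here. As a hedged sketch, one way to make higher RVQ levels progressively less likely (consistent with the comment above) is a geometric walk upward from `lo`; the halving factor and function name below are assumptions, not necessarily the repo's actual falloff.

```python
import random
from collections import Counter

# Hypothetical level sampler biased toward low RVQ levels (level 0 = AR, 1+ = NAR).
def sample_rvq_level(lo=0, hi=8):
	index = lo
	# advance with probability 0.5 at each step, so level k is ~2x less likely than k-1
	while index < hi - 1 and random.random() < 0.5:
		index += 1
	return index

# quick sanity check of the resulting distribution
print(Counter(sample_rvq_level() for _ in range(10_000)))
```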
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 5b211f6..1558c95 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -282,14 +282,6 @@ class Base(nn.Module):
 	def causal(self) -> bool:
 		raise NotImplementedError
 
-	@property
-	def arch_type(self) -> str:
-		raise NotImplementedError
-
-	@property
-	def norm_type(self):
-		raise NotImplementedError
-
 	@property
 	def n_prom_levels(self) -> int:
 		raise NotImplementedError
@@ -377,6 +369,9 @@ class Base(nn.Module):
 		self.l_padding = l_padding
 
 		n_prom_tokens = n_audio_tokens
+		arch_type = self.config.arch_type if self.config is not None else "llama"
+
+		self.arch_type = arch_type
 
 		# check if requested arch is unavailable
 		if self.arch_type in ERROR_ARCHES:
@@ -392,7 +387,7 @@
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
-		audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else ""
+		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 
 		self.langs_emb = None
@@ -420,12 +415,12 @@
 			self.proms_emb = AudioEmbedding(
 				[n_prom_tokens] * self.n_prom_levels, d_model,
 				sums=audio_embedding_sums,
-				external_mode=audio_embeddings_mode,
+				external_mode=audio_embedding_mode,
 			)
 			self.resps_emb = AudioEmbedding(
 				l_tokens, d_model,
 				sums=audio_embedding_sums,
-				external_mode=audio_embeddings_mode,
+				external_mode=audio_embedding_mode,
 			)
 
 		# useless since I actually removed using these with the input processing overhaul...
@@ -471,7 +466,7 @@
 				n_heads=n_heads,
 				p_dropout=p_dropout if training else 0.0,
 				causal=self.causal,
-				norm_type=self.norm_type,
+				norm_type="ln", # adaln
 				n_levels=self.n_resp_levels,
 			) for _ in range(n_layers) ])
 		elif self.arch_type in ["mistral", "mixtral"]:
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 8acceca..d5f0e9c 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -27,26 +27,10 @@ class NAR(Base):
 			return self.config.capabilities
 		return cfg.model.capabilities
 
-	@property
-	def quant_level_range(self) -> list[int]:
-		if hasattr(self, "config") and self.config.rvq_level_range:
-			return self.config.rvq_level_range
-		return [ 0 if self.causal else 1, self.n_resp_levels ]
-
 	@property
 	def causal(self):
 		return "len" in self.capabilities
 
-	@property
-	def norm_type(self):
-		return "ln" # if self.n_resp_levels == 1 else "adaln"
-
-	@property
-	def arch_type(self) -> str:
-		if hasattr(self, "config") and self.config:
-			return self.config.arch_type
-		return cfg.model.arch_type
-
 	@property
 	def n_prom_levels(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -71,12 +55,6 @@
 		return self.config.tasks
 		return cfg.model.tasks
 
-	@property
-	def p_rvq_levels(self) -> int:
-		if hasattr(self, "config") and self.config:
-			return self.config.p_rvq_levels
-		return cfg.model.p_rvq_levels
-
 	@property
 	def n_langs(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -159,12 +137,14 @@
 		task_list = [ sample_task() for _ in range(batch_size) ]
 
 		# determines which RVQ level to target per batch
-		quant_level_range = self.quant_level_range
+		quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
 
-		if self.p_rvq_levels == "equal":
+		p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
+
+		if p_rvq_levels == "equal":
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
-		else: # if self.p_rvq_levels == "auto":
+		else: # if p_rvq_levels == "auto":
 			# makes higher levels less likely
 			def generate( lo=0, hi=8 ):
 				index = lo
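`audio_embedding_mode` now threads through to `AudioEmbedding(external_mode=...)` for both `proms_emb` and `resps_emb`. Below is a hedged sketch of what "inclusive" vs "exclusive" plausibly does with the audio backend's codebook embeddings; the class, names, and shapes are assumptions (and the `sums` behavior is ignored for brevity), not the repo's actual `AudioEmbedding`.

```python
import torch
import torch.nn as nn

class AudioEmbeddingSketch(nn.Module):
	# external_emb stands in for the audio backend's (e.g. vocos/dac) codebook,
	# frozen and projected to d_model; how the repo obtains it is not shown in this diff
	def __init__(self, n_tokens, d_model, external_emb=None, external_mode=""):
		super().__init__()
		self.learned = nn.Embedding(n_tokens, d_model)
		self.external = external_emb
		self.external_mode = external_mode

	def forward(self, codes):
		if self.external_mode == "exclusive" and self.external is not None:
			return self.external(codes)  # rely solely on the backend's embeddings
		out = self.learned(codes)
		if self.external_mode == "inclusive" and self.external is not None:
			out = out + self.external(codes)  # mix backend embeddings into the learned ones
		return out

# usage: [batch, frames] codes -> [batch, frames, d_model] embeddings
emb = AudioEmbeddingSketch(1024, 512, external_emb=None, external_mode="")
x = emb(torch.randint(0, 1024, (2, 75)))
```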