more cleanup

This commit is contained in:
mrq 2024-06-30 11:00:12 -05:00
parent bc2a6fa756
commit dced595391
5 changed files with 43 additions and 82 deletions

View File

@ -2,24 +2,31 @@ sample_rate: 24_000 # 44_000 for dac
audio_backend: "vocos" # or dac audio_backend: "vocos" # or dac
models: models:
- name: "ar+nar" - name: "ar+nar" # vanity name
size: "full" size: "full" # model dimensionality
resp_levels: 8 resp_levels: 8 # RVQ levels this model targets
prom_levels: 8 prom_levels: 8 # should always be the above
tasks: 8 tasks: 8 # tasks this model can attend to, only tts is supported at the moment
langs: 2 langs: 2 # languages this model supports, semi-unused at the moment
tones: 1 tones: 1 # tones this model supports, currently unused
arch_type: llama arch_type: llama # underlying LLM arch to use, currently focusing on llama
training: True training: True # signals this model is to be trained
version: 5 version: 5 # helps keep backwards compatibility for when I add new things to the model
attention: flash_attention_2 attention: auto # attention mechanism to use, "auto" for safety
dropout: 0.1 dropout: 0.1 # percentage of the model to disable during training
experimental: False
# factors for split loss values, remove to have a unified loss calculation
loss_factors: loss_factors:
text: 0.1 text: 0.1 # text phoneme portion of the sequence
prom: 0.0 prom: 0.0 # input prompt portion of the sequence
resp: 1.0 resp: 1.0 # output audio portin of the sequence
# experimental settings
experimental:
hf: False # uses vall_e.models.experimental, a wrapper around a HF model that could technically be used for non-pytorch backends later
interleave: False # interleaves RVQ levels, only works with above for now
audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
hyperparameters: hyperparameters:
autotune: False autotune: False

View File

@ -201,8 +201,10 @@ class ModelExperimentalSettings:
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal) interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends) kv_heads: int = 0 # MHA or GQA (for supported backends)
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
# I really need to clean this up # I really need to clean this up
@dataclass() @dataclass()
@ -225,9 +227,6 @@ class Model:
loss_factors: dict = field(default_factory=lambda: {}) loss_factors: dict = field(default_factory=lambda: {})
capabilities: list = field(default_factory=lambda: ["ar", "nar"]) capabilities: list = field(default_factory=lambda: ["ar", "nar"])
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
experimental: dict | ModelExperimentalSettings | None = None # experimental settings experimental: dict | ModelExperimentalSettings | None = None # experimental settings
def get(self, name=None): def get(self, name=None):

View File

@ -29,26 +29,10 @@ class AR_NAR(Base):
return self.config.capabilities return self.config.capabilities
return cfg.model.capabilities return cfg.model.capabilities
@property
def quant_level_range(self) -> list[int]:
if hasattr(self, "config") and self.config.rvq_level_range:
return self.config.rvq_level_range
return [ 0 if self.causal else 1, self.n_resp_levels ]
@property @property
def causal(self): def causal(self):
return "ar" in self.capabilities return "ar" in self.capabilities
@property
def norm_type(self):
return "ln" # if self.n_resp_levels == 1 else "adaln"
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.model.arch_type
@property @property
def n_prom_levels(self) -> int: def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
@ -72,12 +56,6 @@ class AR_NAR(Base):
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
return self.config.tasks return self.config.tasks
return cfg.model.tasks return cfg.model.tasks
@property
def p_rvq_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.p_rvq_levels
return cfg.model.p_rvq_levels
@property @property
def n_langs(self) -> int: def n_langs(self) -> int:
@ -159,16 +137,18 @@ class AR_NAR(Base):
def sample_task(): def sample_task():
return "tts" return "tts"
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
# generate task list to train against # generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ] task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch # determines which RVQ level to target per batch
quant_level_range = self.quant_level_range quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
if self.p_rvq_levels == "equal": if p_rvq_levels == "equal":
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if self.p_rvq_levels == "auto": else: # if p_rvq_levels == "auto":
# makes higher levels less likely # makes higher levels less likely
def generate( lo=0, hi=8 ): def generate( lo=0, hi=8 ):
index = lo index = lo

View File

@ -282,14 +282,6 @@ class Base(nn.Module):
def causal(self) -> bool: def causal(self) -> bool:
raise NotImplementedError raise NotImplementedError
@property
def arch_type(self) -> str:
raise NotImplementedError
@property
def norm_type(self):
raise NotImplementedError
@property @property
def n_prom_levels(self) -> int: def n_prom_levels(self) -> int:
raise NotImplementedError raise NotImplementedError
@ -377,6 +369,9 @@ class Base(nn.Module):
self.l_padding = l_padding self.l_padding = l_padding
n_prom_tokens = n_audio_tokens n_prom_tokens = n_audio_tokens
arch_type = self.config.arch_type if self.config is not None else "llama"
self.arch_type = arch_type
# check if requested arch is unavailable # check if requested arch is unavailable
if self.arch_type in ERROR_ARCHES: if self.arch_type in ERROR_ARCHES:
@ -392,7 +387,7 @@ class Base(nn.Module):
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else "" audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
self.text_emb = Embedding(n_text_tokens, d_model) self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None self.langs_emb = None
@ -420,12 +415,12 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding( self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model, [n_prom_tokens] * self.n_prom_levels, d_model,
sums=audio_embedding_sums, sums=audio_embedding_sums,
external_mode=audio_embeddings_mode, external_mode=audio_embedding_mode,
) )
self.resps_emb = AudioEmbedding( self.resps_emb = AudioEmbedding(
l_tokens, d_model, l_tokens, d_model,
sums=audio_embedding_sums, sums=audio_embedding_sums,
external_mode=audio_embeddings_mode, external_mode=audio_embedding_mode,
) )
# useless since I actually removed using these with the input processing overhaul... # useless since I actually removed using these with the input processing overhaul...
@ -471,7 +466,7 @@ class Base(nn.Module):
n_heads=n_heads, n_heads=n_heads,
p_dropout=p_dropout if training else 0.0, p_dropout=p_dropout if training else 0.0,
causal=self.causal, causal=self.causal,
norm_type=self.norm_type, norm_type="ln", # adaln
n_levels=self.n_resp_levels, n_levels=self.n_resp_levels,
) for _ in range(n_layers) ]) ) for _ in range(n_layers) ])
elif self.arch_type in ["mistral", "mixtral"]: elif self.arch_type in ["mistral", "mixtral"]:

View File

@ -27,26 +27,10 @@ class NAR(Base):
return self.config.capabilities return self.config.capabilities
return cfg.model.capabilities return cfg.model.capabilities
@property
def quant_level_range(self) -> list[int]:
if hasattr(self, "config") and self.config.rvq_level_range:
return self.config.rvq_level_range
return [ 0 if self.causal else 1, self.n_resp_levels ]
@property @property
def causal(self): def causal(self):
return "len" in self.capabilities return "len" in self.capabilities
@property
def norm_type(self):
return "ln" # if self.n_resp_levels == 1 else "adaln"
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.model.arch_type
@property @property
def n_prom_levels(self) -> int: def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
@ -71,12 +55,6 @@ class NAR(Base):
return self.config.tasks return self.config.tasks
return cfg.model.tasks return cfg.model.tasks
@property
def p_rvq_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.p_rvq_levels
return cfg.model.p_rvq_levels
@property @property
def n_langs(self) -> int: def n_langs(self) -> int:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
@ -159,12 +137,14 @@ class NAR(Base):
task_list = [ sample_task() for _ in range(batch_size) ] task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch # determines which RVQ level to target per batch
quant_level_range = self.quant_level_range quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
if self.p_rvq_levels == "equal": p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
if p_rvq_levels == "equal":
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
else: # if self.p_rvq_levels == "auto": else: # if p_rvq_levels == "auto":
# makes higher levels less likely # makes higher levels less likely
def generate( lo=0, hi=8 ): def generate( lo=0, hi=8 ):
index = lo index = lo