more cleanup
This commit is contained in:
parent
bc2a6fa756
commit
dced595391
|
@ -2,24 +2,31 @@ sample_rate: 24_000 # 44_000 for dac
|
|||
audio_backend: "vocos" # or dac
|
||||
|
||||
models:
|
||||
- name: "ar+nar"
|
||||
size: "full"
|
||||
resp_levels: 8
|
||||
prom_levels: 8
|
||||
tasks: 8
|
||||
langs: 2
|
||||
tones: 1
|
||||
arch_type: llama
|
||||
training: True
|
||||
version: 5
|
||||
attention: flash_attention_2
|
||||
dropout: 0.1
|
||||
experimental: False
|
||||
- name: "ar+nar" # vanity name
|
||||
size: "full" # model dimensionality
|
||||
resp_levels: 8 # RVQ levels this model targets
|
||||
prom_levels: 8 # should always match resp_levels above
|
||||
tasks: 8 # tasks this model can attend to, only tts is supported at the moment
|
||||
langs: 2 # languages this model supports, semi-unused at the moment
|
||||
tones: 1 # tones this model supports, currently unused
|
||||
arch_type: llama # underlying LLM arch to use, currently focusing on llama
|
||||
training: True # signals this model is to be trained
|
||||
version: 5 # helps keep backwards compatibility for when I add new things to the model
|
||||
attention: auto # attention mechanism to use, "auto" for safety
|
||||
dropout: 0.1 # percentage of the model to disable during training
|
||||
|
||||
# factors for split loss values, remove to have a unified loss calculation
|
||||
loss_factors:
|
||||
text: 0.1
|
||||
prom: 0.0
|
||||
resp: 1.0
|
||||
text: 0.1 # text phoneme portion of the sequence
|
||||
prom: 0.0 # input prompt portion of the sequence
|
||||
resp: 1.0 # output audio portion of the sequence
|
||||
|
||||
# experimental settings
|
||||
experimental:
|
||||
hf: False # uses vall_e.models.experimental, a wrapper around a HF model that could technically be used for non-pytorch backends later
|
||||
interleave: False # interleaves RVQ levels, only works with above for now
|
||||
audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
|
||||
audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
|
||||
|
||||
hyperparameters:
|
||||
autotune: False
|
||||
|
|
|
@ -201,8 +201,10 @@ class ModelExperimentalSettings:
|
|||
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
|
||||
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
|
||||
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
|
||||
audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
|
||||
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", repurposes the audio backend's encoding/decoding model for embeddings
|
||||
kv_heads: int = 0 # MHA or GQA (for supported backends)
|
||||
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
|
||||
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
|
@ -225,9 +227,6 @@ class Model:
|
|||
loss_factors: dict = field(default_factory=lambda: {})
|
||||
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
|
||||
|
||||
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
|
||||
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
|
||||
|
||||
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
|
||||
|
||||
def get(self, name=None):
|
||||
|
|
|
@ -29,26 +29,10 @@ class AR_NAR(Base):
|
|||
return self.config.capabilities
|
||||
return cfg.model.capabilities
|
||||
|
||||
@property
|
||||
def quant_level_range(self) -> list[int]:
|
||||
if hasattr(self, "config") and self.config.rvq_level_range:
|
||||
return self.config.rvq_level_range
|
||||
return [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
@property
|
||||
def causal(self):
|
||||
return "ar" in self.capabilities
|
||||
|
||||
@property
|
||||
def norm_type(self):
|
||||
return "ln" # if self.n_resp_levels == 1 else "adaln"
|
||||
|
||||
@property
|
||||
def arch_type(self) -> str:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.arch_type
|
||||
return cfg.model.arch_type
|
||||
|
||||
@property
|
||||
def n_prom_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
|
@ -73,12 +57,6 @@ class AR_NAR(Base):
|
|||
return self.config.tasks
|
||||
return cfg.model.tasks
|
||||
|
||||
@property
|
||||
def p_rvq_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.p_rvq_levels
|
||||
return cfg.model.p_rvq_levels
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
|
@ -159,16 +137,18 @@ class AR_NAR(Base):
|
|||
def sample_task():
|
||||
return "tts"
|
||||
|
||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||
|
||||
# generate task list to train against
|
||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.quant_level_range
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
if self.p_rvq_levels == "equal":
|
||||
if p_rvq_levels == "equal":
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||
else: # if self.p_rvq_levels == "auto":
|
||||
else: # if p_rvq_levels == "auto":
|
||||
# makes higher levels less likely
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
|
|
|
@ -282,14 +282,6 @@ class Base(nn.Module):
|
|||
def causal(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def arch_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def norm_type(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def n_prom_levels(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
@ -377,6 +369,9 @@ class Base(nn.Module):
|
|||
self.l_padding = l_padding
|
||||
|
||||
n_prom_tokens = n_audio_tokens
|
||||
arch_type = self.config.arch_type if self.config is not None else "llama"
|
||||
|
||||
self.arch_type = arch_type
|
||||
|
||||
# check if requested arch is unavailable
|
||||
if self.arch_type in ERROR_ARCHES:
|
||||
|
@ -392,7 +387,7 @@ class Base(nn.Module):
|
|||
|
||||
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
||||
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
||||
audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else ""
|
||||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -420,12 +415,12 @@ class Base(nn.Module):
|
|||
self.proms_emb = AudioEmbedding(
|
||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||
sums=audio_embedding_sums,
|
||||
external_mode=audio_embeddings_mode,
|
||||
external_mode=audio_embedding_mode,
|
||||
)
|
||||
self.resps_emb = AudioEmbedding(
|
||||
l_tokens, d_model,
|
||||
sums=audio_embedding_sums,
|
||||
external_mode=audio_embeddings_mode,
|
||||
external_mode=audio_embedding_mode,
|
||||
)
|
||||
|
||||
# useless since I actually removed using these with the input processing overhaul...
|
||||
|
@ -471,7 +466,7 @@ class Base(nn.Module):
|
|||
n_heads=n_heads,
|
||||
p_dropout=p_dropout if training else 0.0,
|
||||
causal=self.causal,
|
||||
norm_type=self.norm_type,
|
||||
norm_type="ln", # adaln
|
||||
n_levels=self.n_resp_levels,
|
||||
) for _ in range(n_layers) ])
|
||||
elif self.arch_type in ["mistral", "mixtral"]:
|
||||
|
|
|
@ -27,26 +27,10 @@ class NAR(Base):
|
|||
return self.config.capabilities
|
||||
return cfg.model.capabilities
|
||||
|
||||
@property
|
||||
def quant_level_range(self) -> list[int]:
|
||||
if hasattr(self, "config") and self.config.rvq_level_range:
|
||||
return self.config.rvq_level_range
|
||||
return [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
@property
|
||||
def causal(self):
|
||||
return "len" in self.capabilities
|
||||
|
||||
@property
|
||||
def norm_type(self):
|
||||
return "ln" # if self.n_resp_levels == 1 else "adaln"
|
||||
|
||||
@property
|
||||
def arch_type(self) -> str:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.arch_type
|
||||
return cfg.model.arch_type
|
||||
|
||||
@property
|
||||
def n_prom_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
|
@ -71,12 +55,6 @@ class NAR(Base):
|
|||
return self.config.tasks
|
||||
return cfg.model.tasks
|
||||
|
||||
@property
|
||||
def p_rvq_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.p_rvq_levels
|
||||
return cfg.model.p_rvq_levels
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
|
@ -159,12 +137,14 @@ class NAR(Base):
|
|||
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.quant_level_range
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
if self.p_rvq_levels == "equal":
|
||||
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||
|
||||
if p_rvq_levels == "equal":
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||
else: # if self.p_rvq_levels == "auto":
|
||||
else: # if p_rvq_levels == "auto":
|
||||
# makes higher levels less likely
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
|
|
Loading…
Reference in New Issue
Block a user