more cleanup
This commit is contained in:
parent
bc2a6fa756
commit
dced595391
|
@ -2,24 +2,31 @@ sample_rate: 24_000 # 44_000 for dac
|
||||||
audio_backend: "vocos" # or dac
|
audio_backend: "vocos" # or dac
|
||||||
|
|
||||||
models:
|
models:
|
||||||
- name: "ar+nar"
|
- name: "ar+nar" # vanity name
|
||||||
size: "full"
|
size: "full" # model dimensionality
|
||||||
resp_levels: 8
|
resp_levels: 8 # RVQ levels this model targets
|
||||||
prom_levels: 8
|
prom_levels: 8 # should always be the above
|
||||||
tasks: 8
|
tasks: 8 # tasks this model can attend to, only tts is supported at the moment
|
||||||
langs: 2
|
langs: 2 # languages this model supports, semi-unused at the moment
|
||||||
tones: 1
|
tones: 1 # tones this model supports, currently unused
|
||||||
arch_type: llama
|
arch_type: llama # underlying LLM arch to use, currently focusing on llama
|
||||||
training: True
|
training: True # signals this model is to be trained
|
||||||
version: 5
|
version: 5 # helps keep backwards compatibility for when I add new things to the model
|
||||||
attention: flash_attention_2
|
attention: auto # attention mechanism to use, "auto" for safety
|
||||||
dropout: 0.1
|
dropout: 0.1 # percentage of the model to disable during training
|
||||||
experimental: False
|
|
||||||
|
|
||||||
|
# factors for split loss values, remove to have a unified loss calculation
|
||||||
loss_factors:
|
loss_factors:
|
||||||
text: 0.1
|
text: 0.1 # text phoneme portion of the sequence
|
||||||
prom: 0.0
|
prom: 0.0 # input prompt portion of the sequence
|
||||||
resp: 1.0
|
resp: 1.0 # output audio portin of the sequence
|
||||||
|
|
||||||
|
# experimental settings
|
||||||
|
experimental:
|
||||||
|
hf: False # uses vall_e.models.experimental, a wrapper around a HF model that could technically be used for non-pytorch backends later
|
||||||
|
interleave: False # interleaves RVQ levels, only works with above for now
|
||||||
|
audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
|
||||||
|
audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
|
||||||
|
|
||||||
hyperparameters:
|
hyperparameters:
|
||||||
autotune: False
|
autotune: False
|
||||||
|
|
|
@ -201,8 +201,10 @@ class ModelExperimentalSettings:
|
||||||
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
|
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
|
||||||
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
|
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head
|
||||||
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
|
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
|
||||||
audio_embeddings_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
|
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
|
||||||
kv_heads: int = 0 # MHA or GQA (for supported backends)
|
kv_heads: int = 0 # MHA or GQA (for supported backends)
|
||||||
|
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
|
||||||
|
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
|
||||||
|
|
||||||
# I really need to clean this up
|
# I really need to clean this up
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
@ -225,9 +227,6 @@ class Model:
|
||||||
loss_factors: dict = field(default_factory=lambda: {})
|
loss_factors: dict = field(default_factory=lambda: {})
|
||||||
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
|
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
|
||||||
|
|
||||||
p_rvq_levels: str = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
|
|
||||||
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
|
|
||||||
|
|
||||||
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
|
experimental: dict | ModelExperimentalSettings | None = None # experimental settings
|
||||||
|
|
||||||
def get(self, name=None):
|
def get(self, name=None):
|
||||||
|
|
|
@ -29,26 +29,10 @@ class AR_NAR(Base):
|
||||||
return self.config.capabilities
|
return self.config.capabilities
|
||||||
return cfg.model.capabilities
|
return cfg.model.capabilities
|
||||||
|
|
||||||
@property
|
|
||||||
def quant_level_range(self) -> list[int]:
|
|
||||||
if hasattr(self, "config") and self.config.rvq_level_range:
|
|
||||||
return self.config.rvq_level_range
|
|
||||||
return [ 0 if self.causal else 1, self.n_resp_levels ]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def causal(self):
|
def causal(self):
|
||||||
return "ar" in self.capabilities
|
return "ar" in self.capabilities
|
||||||
|
|
||||||
@property
|
|
||||||
def norm_type(self):
|
|
||||||
return "ln" # if self.n_resp_levels == 1 else "adaln"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def arch_type(self) -> str:
|
|
||||||
if hasattr(self, "config") and self.config:
|
|
||||||
return self.config.arch_type
|
|
||||||
return cfg.model.arch_type
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_prom_levels(self) -> int:
|
def n_prom_levels(self) -> int:
|
||||||
if hasattr(self, "config") and self.config:
|
if hasattr(self, "config") and self.config:
|
||||||
|
@ -73,12 +57,6 @@ class AR_NAR(Base):
|
||||||
return self.config.tasks
|
return self.config.tasks
|
||||||
return cfg.model.tasks
|
return cfg.model.tasks
|
||||||
|
|
||||||
@property
|
|
||||||
def p_rvq_levels(self) -> int:
|
|
||||||
if hasattr(self, "config") and self.config:
|
|
||||||
return self.config.p_rvq_levels
|
|
||||||
return cfg.model.p_rvq_levels
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_langs(self) -> int:
|
def n_langs(self) -> int:
|
||||||
if hasattr(self, "config") and self.config:
|
if hasattr(self, "config") and self.config:
|
||||||
|
@ -159,16 +137,18 @@ class AR_NAR(Base):
|
||||||
def sample_task():
|
def sample_task():
|
||||||
return "tts"
|
return "tts"
|
||||||
|
|
||||||
|
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||||
|
|
||||||
# generate task list to train against
|
# generate task list to train against
|
||||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||||
|
|
||||||
# determines which RVQ level to target per batch
|
# determines which RVQ level to target per batch
|
||||||
quant_level_range = self.quant_level_range
|
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||||
|
|
||||||
if self.p_rvq_levels == "equal":
|
if p_rvq_levels == "equal":
|
||||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||||
else: # if self.p_rvq_levels == "auto":
|
else: # if p_rvq_levels == "auto":
|
||||||
# makes higher levels less likely
|
# makes higher levels less likely
|
||||||
def generate( lo=0, hi=8 ):
|
def generate( lo=0, hi=8 ):
|
||||||
index = lo
|
index = lo
|
||||||
|
|
|
@ -282,14 +282,6 @@ class Base(nn.Module):
|
||||||
def causal(self) -> bool:
|
def causal(self) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
|
||||||
def arch_type(self) -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def norm_type(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_prom_levels(self) -> int:
|
def n_prom_levels(self) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -377,6 +369,9 @@ class Base(nn.Module):
|
||||||
self.l_padding = l_padding
|
self.l_padding = l_padding
|
||||||
|
|
||||||
n_prom_tokens = n_audio_tokens
|
n_prom_tokens = n_audio_tokens
|
||||||
|
arch_type = self.config.arch_type if self.config is not None else "llama"
|
||||||
|
|
||||||
|
self.arch_type = arch_type
|
||||||
|
|
||||||
# check if requested arch is unavailable
|
# check if requested arch is unavailable
|
||||||
if self.arch_type in ERROR_ARCHES:
|
if self.arch_type in ERROR_ARCHES:
|
||||||
|
@ -392,7 +387,7 @@ class Base(nn.Module):
|
||||||
|
|
||||||
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
|
||||||
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
|
||||||
audio_embeddings_mode = self.config.experimental.audio_embeddings_mode if self.config is not None else ""
|
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||||
|
|
||||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||||
self.langs_emb = None
|
self.langs_emb = None
|
||||||
|
@ -420,12 +415,12 @@ class Base(nn.Module):
|
||||||
self.proms_emb = AudioEmbedding(
|
self.proms_emb = AudioEmbedding(
|
||||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||||
sums=audio_embedding_sums,
|
sums=audio_embedding_sums,
|
||||||
external_mode=audio_embeddings_mode,
|
external_mode=audio_embedding_mode,
|
||||||
)
|
)
|
||||||
self.resps_emb = AudioEmbedding(
|
self.resps_emb = AudioEmbedding(
|
||||||
l_tokens, d_model,
|
l_tokens, d_model,
|
||||||
sums=audio_embedding_sums,
|
sums=audio_embedding_sums,
|
||||||
external_mode=audio_embeddings_mode,
|
external_mode=audio_embedding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# useless since I actually removed using these with the input processing overhaul...
|
# useless since I actually removed using these with the input processing overhaul...
|
||||||
|
@ -471,7 +466,7 @@ class Base(nn.Module):
|
||||||
n_heads=n_heads,
|
n_heads=n_heads,
|
||||||
p_dropout=p_dropout if training else 0.0,
|
p_dropout=p_dropout if training else 0.0,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
norm_type=self.norm_type,
|
norm_type="ln", # adaln
|
||||||
n_levels=self.n_resp_levels,
|
n_levels=self.n_resp_levels,
|
||||||
) for _ in range(n_layers) ])
|
) for _ in range(n_layers) ])
|
||||||
elif self.arch_type in ["mistral", "mixtral"]:
|
elif self.arch_type in ["mistral", "mixtral"]:
|
||||||
|
|
|
@ -27,26 +27,10 @@ class NAR(Base):
|
||||||
return self.config.capabilities
|
return self.config.capabilities
|
||||||
return cfg.model.capabilities
|
return cfg.model.capabilities
|
||||||
|
|
||||||
@property
|
|
||||||
def quant_level_range(self) -> list[int]:
|
|
||||||
if hasattr(self, "config") and self.config.rvq_level_range:
|
|
||||||
return self.config.rvq_level_range
|
|
||||||
return [ 0 if self.causal else 1, self.n_resp_levels ]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def causal(self):
|
def causal(self):
|
||||||
return "len" in self.capabilities
|
return "len" in self.capabilities
|
||||||
|
|
||||||
@property
|
|
||||||
def norm_type(self):
|
|
||||||
return "ln" # if self.n_resp_levels == 1 else "adaln"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def arch_type(self) -> str:
|
|
||||||
if hasattr(self, "config") and self.config:
|
|
||||||
return self.config.arch_type
|
|
||||||
return cfg.model.arch_type
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_prom_levels(self) -> int:
|
def n_prom_levels(self) -> int:
|
||||||
if hasattr(self, "config") and self.config:
|
if hasattr(self, "config") and self.config:
|
||||||
|
@ -71,12 +55,6 @@ class NAR(Base):
|
||||||
return self.config.tasks
|
return self.config.tasks
|
||||||
return cfg.model.tasks
|
return cfg.model.tasks
|
||||||
|
|
||||||
@property
|
|
||||||
def p_rvq_levels(self) -> int:
|
|
||||||
if hasattr(self, "config") and self.config:
|
|
||||||
return self.config.p_rvq_levels
|
|
||||||
return cfg.model.p_rvq_levels
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_langs(self) -> int:
|
def n_langs(self) -> int:
|
||||||
if hasattr(self, "config") and self.config:
|
if hasattr(self, "config") and self.config:
|
||||||
|
@ -159,12 +137,14 @@ class NAR(Base):
|
||||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||||
|
|
||||||
# determines which RVQ level to target per batch
|
# determines which RVQ level to target per batch
|
||||||
quant_level_range = self.quant_level_range
|
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||||
|
|
||||||
if self.p_rvq_levels == "equal":
|
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
|
||||||
|
|
||||||
|
if p_rvq_levels == "equal":
|
||||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||||
else: # if self.p_rvq_levels == "auto":
|
else: # if p_rvq_levels == "auto":
|
||||||
# makes higher levels less likely
|
# makes higher levels less likely
|
||||||
def generate( lo=0, hi=8 ):
|
def generate( lo=0, hi=8 ):
|
||||||
index = lo
|
index = lo
|
||||||
|
|
Loading…
Reference in New Issue
Block a user