adjusted how I want to pass eval kwargs

mrq 2024-10-25 20:38:09 -05:00
parent 92e6bff6dc
commit a96f5aee32
4 changed files with 72 additions and 37 deletions

View File

@@ -6,9 +6,8 @@ models:
- name: "ar+nar" # vanity name
size: "full" # model dimensionality
resp_levels: 8 # RVQ levels this model targets
prom_levels: 8 # should always be the above
tasks: 8 # tasks this model can attend to, only tts is guaranteed results at the moment
langs: 2 # languages this model supports, semi-unused at the moment
tasks: 9 # tasks this model can attend to, only tts is guaranteed results at the moment
langs: 4 # languages this model supports
tones: 1 # tones this model supports, currently unused
arch_type: llama # underlying LLM arch to use, currently focusing on llama
training: True # signals this model is to be trained
@@ -22,14 +21,22 @@ models:
prom: 0.5 # input prompt portion of the sequence
resp: 1.0 # output audio portion of the sequence
capabilities: ["ar", "nar"] # macro-tasks this model can perform
# experimental settings
experimental:
hf: False # uses vall_e.models.experimental, a wrapper around a HF model that could technically be used for non-pytorch backends later
interleave: False # interleaves RVQ levels, only works with above for now
audio_embedding_mode: "" # "" | "inclusive" | "exclusive", whether to utilize the audio backend's embeddings with the input embeddings
audio_embedding_sums: False # whether the input embeddings include all prior RVQ levels (sums) or only the current one, further experimentation is needed to see if this matters
p_rvq_levels: "auto" # "equal" | "auto", sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
rvq_levels_p: "auto" # "equal" | "auto" | list[int], sets probabilities of which RVQ level to select during training, auto will have the next RVQ level half as likely as the previous one
audio_embedding_sums: True # whether the input embeddings include all prior RVQ levels (sums) or only the current one (further experimentation is needed to see if this matters)
unified_position_ids: False # specifies whether or not position IDs should be continuous across the whole sequence (if True, naive behavior), or restart them at the next segment of the sequence (if False)
split_classifiers: True # use per-RVQ-level projection/output/classifiers for the model (further experimentation is needed to see if this matters)
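
As an aside on the rvq_levels_p: "auto" setting above (each subsequent RVQ level half as likely as the previous one), here is a standalone sketch of that weighting; this is illustrative only, not the repository's actual level-selection code:

    import random

    # "auto" is described as making each next RVQ level half as likely as the previous
    # one; with 8 levels that gives relative weights 1, 1/2, 1/4, ..., 1/128.
    resp_levels = 8
    weights = [0.5 ** level for level in range(resp_levels)]

    def pick_rvq_level():
        # hypothetical helper; the real selection happens inside the training step
        return random.choices(range(resp_levels), weights=weights, k=1)[0]

    print([round(w, 4) for w in weights])
    print(pick_rvq_level())
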
# list of LoRA(s) to use
#loras:
#- name : "lora-shodan" # LoRA name to load from
# rank: 128 # parameter size per Linear
# alpha: 128 # "influence" value
# training: True #
# rvq_levels: [] # RVQ levels to activate the LoRA on, leave empty for all
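
For the commented-out LoRA block, rank and alpha follow the usual LoRA convention: the low-rank update is scaled by alpha / rank, so rank == alpha (128 / 128 here) keeps the update at unit scale. A generic sketch of that math, not this repository's LoRA implementation:

    import torch

    # Generic LoRA update for a single Linear weight: W' = W + (alpha / rank) * (B @ A)
    d_out, d_in, rank, alpha = 1024, 1024, 128, 128
    W = torch.randn(d_out, d_in)
    A = torch.randn(rank, d_in) * 0.01  # trained low-rank factor
    B = torch.zeros(d_out, rank)        # initialised to zero so the adapter starts as a no-op

    W_adapted = W + (alpha / rank) * (B @ A)
    print(torch.equal(W_adapted, W))    # True until B is trained away from zero
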
# hyperparameter settings (could be relegated to trainer settings)
hyperparameters:
@@ -59,16 +66,14 @@ evaluation:
batch_size: 8 # batch size for evaluation / validation pass
frequency: 5000 # how often to perform eval during training
size: 8 # total samples to get for eval
steps: 500 # how many AR steps to perform
ar_temperature: 0.95 # temperature for AR sampling
nar_temperature: 0.25 # temperature for NAR sampling
load_disabled_engines: True # deprecated
# arguments to pass for the AR/NAR (matches arguments passed through vall_e.inference)
kwargs:
max_steps: 500 # how many AR steps to perform
ar_temp: 0.95 # temperature for AR sampling
nar_temp: 0.25 # temperature for NAR sampling
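
The flat steps / ar_temperature / nar_temperature fields are folded into a single kwargs mapping whose keys follow the vall_e.inference-style names (max_steps, ar_temp, nar_temp); additional sampling keys such as top_p, top_k or repetition_penalty can presumably sit alongside them, since the new ar_kwargs / nar_kwargs properties further down read from the same dict. A rough Python-side equivalence, as a sketch rather than the actual parsing code:

    # old flat evaluation fields ...
    old = dict(steps=500, ar_temperature=0.95, nar_temperature=0.25)

    # ... and their new home under evaluation.kwargs, using inference-style names
    new_kwargs = dict(
        max_steps=old["steps"],
        ar_temp=old["ar_temperature"],
        nar_temp=old["nar_temperature"],
    )
    print(new_kwargs)  # {'max_steps': 500, 'ar_temp': 0.95, 'nar_temp': 0.25}
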
trainer:
#no_logger: True # deprecated, because the logger should always work now
ddp: False # whether to wrap the model with DDP, should automatically be set
#check_for_oom: False # wrap forward/backwards in a try/catch block and gracefully handles OOM conditions
iterations: 1_000_000 # how many total iterations to train before terminating, should just have this as 0 by default to not auto-terminate
save_tag: step # tag name to save checkpoints under
@@ -82,26 +87,27 @@ trainer:
gradient_checkpointing: True # gradient checkpointing to save VRAM at the cost of some performance throughput
strict_loading: False # strict state dict loading (set to False if you're going to change some model settings)
resize_modules: True # automatically resize core modules from the state dict to match
#check_for_oom: False # wrap forward/backwards in a try/catch block and gracefully handles OOM conditions
#load_state_dict: True # load the state dict from fp32.pth instead of a checkpoint, should automagically be done
#load_tag: "9500" # specific tag to load from (instead of having to edit latest)
#load_states: False # flag to load optimizer / scheduler states or not
#restart_step_count: True # clear the trainer stats
gc_mode: None # "global_step" # flag to call GC at specific points, seems overkill now
# gc_mode: None # "global_step" # flag to call GC at specific points, seems overkill now
weight_dtype: float32 # float32 | float16 | bfloat16, dtype for the model to load under
amp: False # mixed precision during training
weight_dtype: float16 # float32 | float16 | bfloat16, dtype for the model to load under
amp: True # mixed precision during training
backend: deepspeed # deepspeed | local, training backend to use
# deepspeed specific settings
deepspeed:
inferencing: True # use deepspeed inference wrapper for inferencing, should be relegated under inference
amp: False # use deepspeed's AMP instead (requires nvidia/apex installed)
zero_optimization_level: 0 # ZeRO optimization level to use
use_compression_training: False # compression training (seems useless almost always)
amp: False # use deepspeed's AMP instead (requires nvidia/apex installed)
load_webui: False # initialize the web UI during training (the goal is to let you inference during training, but I never found a good way to go about it)
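
These flags ultimately feed DeepSpeed's own JSON-style configuration. A hedged sketch of the kind of dictionary they plausibly correspond to — the key names follow DeepSpeed's documented schema, but the exact wiring inside vall_e's DeepSpeed wrapper may differ:

    # assumed illustration only; values mirror the YAML flags above
    zero_optimization_level = 0
    use_deepspeed_amp = False

    ds_config = {
        "zero_optimization": {"stage": zero_optimization_level},
        "amp": {"enabled": use_deepspeed_amp},  # DeepSpeed's apex-based AMP
    }
    print(ds_config)
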
# inferencing settings
@@ -145,20 +151,22 @@ dataset:
duration_range: [3.0, 5.0] # allowed sample duration in the dataset
random_utterance: 1.0 # I don't remember desu
max_prompts: 1 # how many prompts to sample to create the input prompt
#prompt_duration: 3.0 # sugar for the below
prompt_duration_range: [3.0, 3.0] # range of durations for the input prompt to be trimmed under
# deprecated
max_resps: 1 # how many random response utterances to sample for the sample
p_resp_append: 0.25 # probability to append additional utterances for the above
prompt_max_samples: 1 # maximum prompts to sample for the input prompt during training
prompt_duration_range: [3.0, 6.0] # duration range for the input prompt during training
prompt_similar_p: 1.0 # odds to instead use a similar utterance instead of a random sample (1 to always do, 0 to never do)
# not used
resps_max_samples: 1 # maximum output utterances to sample for the output during training
resps_append_p: 0.0 # odds to append another utterance to the output utterance sample
sample_type: path # path | speaker | group, type to sample the paths from (by path, speaker, or group)
sample_order: duration # duration | anything else, method of ordering the paths (duration orders by duration, any other value will interleave-reorder)
sample_max_duration_batch: 0 # used when above = duration, 120 seconds per batch at 12GiB of VRAM works
sample_shuffle: False # shuffle indices in the dataloader (avoid using with sample_order: duration and sample_max_duration_batch: 0)
tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"], unused at the moment, but will determine which tasks to use
retokenize_text: False # do not rely on AOT'd tokens from the dataset, instead tokenize JIT (in case you botch your tokenizer during dataset preparation and don't want to recreate it)
tasks_list: [ "tts", "stt" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "stt" ], determines which tasks to randomly pick for a sample
training: [] # paths for the training dataset
validation: [] # paths for the validation dataset
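
Regarding the sample_order: duration and sample_max_duration_batch settings above, a self-contained sketch of what duration-capped batching generally looks like; this is generic bucketing for illustration, not the repository's dataloader:

    def pack_by_duration(durations, max_batch_seconds=120.0):
        # greedily pack duration-sorted sample indices until the per-batch budget is hit
        batches, current, total = [], [], 0.0
        for idx, dur in sorted(enumerate(durations), key=lambda pair: pair[1]):
            if current and total + dur > max_batch_seconds:
                batches.append(current)
                current, total = [], 0.0
            current.append(idx)
            total += dur
        if current:
            batches.append(current)
        return batches

    print(pack_by_duration([3.0, 4.5, 5.0, 30.0, 60.0, 75.0]))  # [[0, 1, 2, 3, 4], [5]]
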

View File

@@ -427,8 +427,36 @@ class Evaluation:
batch_size: int = 64 # number of samples per batch during eval / val
frequency: int = 250 # do eval / val every X iterations
size: int = 64 # number of samples to generate during eval / val
ar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
nar_kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
kwargs: dict = field(default_factory=lambda: {}) # inferencing kwargs
# necessary in order to avoid confusion, since these map onto arguments that aren't directly exposed but still need to be passed to the model
@cached_property
def ar_kwargs( self ):
return dict(
max_steps=self.kwargs["max_ar_steps"],
sampling_temperature=self.kwargs["ar_temp"],
sampling_min_temperature=self.kwargs["min_ar_temp"],
sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"],
sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"],
sampling_length_penalty=self.kwargs["length_penalty"],
sampling_beam_width=self.kwargs["beam_width"],
sampling_mirostat_tau=self.kwargs["mirostat_tau"],
sampling_mirostat_eta=self.kwargs["mirostat_eta"],
sampling_dry_multiplier=self.kwargs["dry_multiplier"],
sampling_dry_base=self.kwargs["dry_base"],
sampling_dry_allowed_length=self.kwargs["dry_allowed_length"],
sampling_entropix=self.kwargs["entropix_sampling"],
)
@cached_property
def nar_kwargs( self ):
return dict(
max_levels=self.kwargs["max_nar_levels"],
sampling_temperature=self.kwargs["nar_temp"],
sampling_min_temperature=self.kwargs["min_nar_temp"],
sampling_top_p=self.kwargs["top_p"], sampling_top_k=self.kwargs["top_k"], sampling_min_p=self.kwargs["min_p"],
sampling_repetition_penalty=self.kwargs["repetition_penalty"], sampling_repetition_penalty_decay=self.kwargs["repetition_penalty_decay"],
)
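
The two cached properties translate the flat evaluation.kwargs mapping into the keyword sets the AR and NAR passes expect. A self-contained sketch of how they might be consumed during the eval pass — the engine call below is a stand-in with example values, not the real invocation in vall_e.train:

    def fake_engine(**kwargs):
        # stand-in for the eval-time model/engine call; it just echoes its keywords
        return sorted(kwargs)

    base_kwargs = dict(text_list=["hello world"], proms_list=[None])  # mirrors run_eval below
    ar_kwargs = dict(max_steps=500, sampling_temperature=0.95)        # e.g. cfg.evaluation.ar_kwargs
    nar_kwargs = dict(max_levels=7, sampling_temperature=0.25)        # e.g. cfg.evaluation.nar_kwargs

    print(fake_engine(**base_kwargs, **ar_kwargs))
    print(fake_engine(**base_kwargs, **nar_kwargs))
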
@dataclass()
class DeepSpeed:
@@ -648,8 +676,8 @@ class Trainer:
@dataclass()
class Inference:
backend: str = "local" # backend to use when inferencing
weight_dtype: str = "float32" # dtype to load the model under
amp: bool = False # automatic mixed precision during inferencing
weight_dtype: str = "float16" # dtype to load the model under
amp: bool = True # automatic mixed precision during inferencing
normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though
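
The inference defaults move to float16 weights with AMP enabled. A hedged sketch of how weight_dtype and amp typically map onto torch primitives — an assumption for illustration, not the repository's exact loading path:

    import torch

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    weight_dtype, amp = "float16", True  # the new Inference defaults

    model = torch.nn.Linear(8, 8)
    if not amp:
        # without AMP, the weights themselves are cast to the requested dtype
        model = model.to(dtype_map[weight_dtype])
    # with AMP, weights usually stay in float32 and the forward pass would instead run
    # under torch.autocast(device_type="cuda", dtype=dtype_map[weight_dtype]) on GPU
    print(next(model.parameters()).dtype)
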

View File

@@ -499,7 +499,7 @@ def example_usage():
available_tasks = ["tts", "stt"]
model = AR_NAR(**kwargs).to(device)
steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@@ -132,7 +132,6 @@ def run_eval(engines, eval_name, dl, args=None):
for name in engines:
engine = engines[name]
base_kwargs = dict(
text_list=batch["text"],
proms_list=batch["proms"],