From 387358bc8adc3641083ee56d3836788d7d7cb0f6 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 31 Jul 2024 20:35:09 -0500
Subject: [PATCH] fixes for the NAR-len model, documenting some config
 options, and a better way to handle resizing modules on state_dict load

---
 vall_e/config.py           | 157 +++++++++++++++++++------------------
 vall_e/engines/__init__.py |  20 +++--
 vall_e/models/base.py      |  12 ++-
 vall_e/models/nar.py       |   2 +
 4 files changed, 104 insertions(+), 87 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 1430368..fca5917 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -31,7 +31,7 @@ def set_seed(seed=None):
 
 @dataclass()
 class BaseConfig:
-	yaml_path: str | None = None
+	yaml_path: str | None = None # path passed in through --yaml
 
 	@property
 	def cfg_path(self):
@@ -124,38 +124,38 @@ class BaseConfig:
 
 @dataclass()
 class Dataset:
-	training: list[Path] = field(default_factory=lambda: [])
-	validation: list[Path] = field(default_factory=lambda: [])
-	noise: list[Path] = field(default_factory=lambda: [])
+	training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
+	validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
+	noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset
 
-	temp: list[Path] = field(default_factory=lambda: [])
+	temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out
 
-	speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
-	speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"
+	speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utterance path
+	speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utterance path
 
 	speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
 
-	hdf5_name: str = "data.h5"
-	use_hdf5: bool = False
-	hdf5_flag: str = "a"
-	use_metadata: bool = False
+	hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
+	use_hdf5: bool = False # whether to load from an HDF5 dataset
+	hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
+	use_metadata: bool = False # use generated metadata to aid in dataset loading
 
-	validate: bool = True
-	workers: int = 8
-	cache: bool = True
+	validate: bool = True # validate each utterance on whether it can be included based on duration range caps
+	workers: int = 8 # number of dataloader workers to spawn
+	cache: bool = True # use diskcache to cache the dataset
 
-	phones_range: list[int] = field(default_factory=lambda: [4, 256])
-	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
-	prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
-	min_utterances: int = 2
+	phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
+	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
+	prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
+	min_utterances: int = 2 # minimum number of utterances a speaker can have
 
-	random_utterance: float = 1.0
-	max_prompts: int = 3
+	random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
+	max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
 	prompt_duration: float | None = None # legacy
 
-	max_resps: int = 1
-	p_resp_append: float = 1.0
+	max_resps: int = 1 # number of samples to target for training
+	p_resp_append: float = 1.0 # probability to append another sample to the training target
 
 	sample_type: str = "path" # path | speaker
 	sample_order: str = "interleaved" # duration
 
@@ -163,11 +163,11 @@ class Dataset:
 	# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
 	sample_shuffle: bool = True # 
 
-	tasks_list: list[str] = field(default_factory=lambda: ["tts"])
+	tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
 
 	reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
 	reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
 	noise_scale: float = 0.25 # scaling noise value
-	inject_noise_in_prom: bool = False
+	inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
 
 	_frames_per_second: int = 0 # allows setting your own hint
@@ -369,7 +369,7 @@ class LoRA:
 	alpha: int = 128 # rank for the LoRA
 	training: bool = True # 
 	embeddings: bool = False # train the embedding too
-	parametrize: bool = False # 
+	parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
 	rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
 
 	@property
@@ -385,42 +385,42 @@ class LoRA:
 
 @dataclass()
 class Hyperparameters:
-	batch_size: int = 8
-	gradient_accumulation_steps: int = 32
-	gradient_clipping: int | float = 100
+	batch_size: int = 8 # number of samples per training batch
+	gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
+	gradient_clipping: int | float = 10 # largest size a gradient norm can be
 
-	optimizer: str = "Adamw" # should be 'Prodigyopt" now
+	optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt' now
 	optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 	learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
-	warmup_steps: int = 0
+	warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates; I think this is just passed to deepspeed
 
-	scheduler: str = ""
+	scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
 	scheduler_type: str = "" # deprecated
 	scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
-	autotune: bool = False
+	autotune: bool = False # to do deepspeed's autotuning
 	autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
-	torch_optimizer: bool = False
-	torch_scheduler: bool = False
+	torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed-supplied
+	torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
 
 @dataclass()
 class Evaluation:
-	batch_size: int = 64
-	frequency: int = 250
-	size: int = 64
-
-	steps: int = 500
-	ar_temperature: float = 1.0
-	nar_temperature: float = 0.0
-	nar_levels: int = 0
+	batch_size: int = 64 # number of samples per batch during eval / val
+	frequency: int = 250 # do eval / val every X iterations
+	size: int = 64 # number of samples to generate during eval / val
 
-	load_disabled_engines: bool = True
+	steps: int = 500
+	ar_temperature: float = 1.0 # AR temp for inferencing
+	nar_temperature: float = 0.0 # NAR temp for inferencing
+	nar_levels: int = 0 # maximum NAR levels to use for inferencing
+
+	load_disabled_engines: bool = True # see the other load_disabled_engines
 
 @dataclass()
 class DeepSpeed:
-	zero_optimization_level: int = 0
+	zero_optimization_level: int = 0 # doesn't seem to work
 	use_compression_training: bool = False # cope
 	compression_bits: int = 8 # cope
 	inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
 
@@ -576,45 +576,46 @@ class DeepSpeed:
 
 @dataclass()
 class Trainer:
-	iterations: int = 100_000
+	iterations: int = 1_000_000 # maximum iterations to train
 
-	save_tag: str = "step"
-	load_tag: str | None = None
+	save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
+	load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
 
-	save_on_oom: bool = True
-	save_on_quit: bool = True
+	save_on_oom: bool = True # save if an OOM error is raised
+	save_on_quit: bool = True # save when quitting training
 
-	export_on_save: bool = False
-	export_on_quit: bool = False
+	export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
+	export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
 
-	save_frequency: int = 100
+	save_frequency: int = 100 # frequency to save every X iterations
 
-	keep_last_checkpoints: int = 0
+	keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
 
-	load_state_dict: bool = False
-	load_states: bool = True
-	strict_loading: bool = False
-	load_module_only: bool = False
-	restart_step_count: bool = False
+	load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
+	load_states: bool = True # 
+	strict_loading: bool = False # whether to enforce strict=True when loading the state dict
+	load_module_only: bool = False # 
+	restart_step_count: bool = False # clears the training stats when loading a checkpoint
+	resize_modules: bool = False # automatically resizes modules in the state dict to match the model config when loading
 
 	activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
-	gradient_checkpointing: bool = True
+	gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
 
-	aggressive_optimizations: bool = False
-	check_for_oom: bool = True
-	gc_mode: str | None = None
-	load_disabled_engines: bool = False
+	aggressive_optimizations: bool = False # deprecated
+	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
+	gc_mode: str | None = None # deprecated, but marks when to do GC
+	load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training, for example for evaluation/validation
 
-	weight_dtype: str = "float16"
+	weight_dtype: str = "float16" # dtype to have the model under
 
-	amp: bool = False
-	ddp: bool = False
+	amp: bool = False # automatic mixed precision
+	ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
 
-	load_webui: bool = False
-	no_logger: bool = False
+	load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
+	no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when the logger broke because of BitNet
 
-	backend: str = "local"
-	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
+	backend: str = "local" # training backend to use; currently supports "local" | "deepspeed"
+	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
 
 	@cached_property
 	def dtype(self):
@@ -638,9 +639,9 @@ class Trainer:
 
 @dataclass()
 class Inference:
-	backend: str = "local"
-	weight_dtype: str = "float32"
-	amp: bool = False
+	backend: str = "local" # backend to use when inferencing
+	weight_dtype: str = "float32" # dtype to load the model under
+	amp: bool = False # automatic mixed precision during inferencing
 
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
 
@@ -681,7 +682,7 @@ class Optimizations:
 
 @dataclass()
 class Config(BaseConfig):
-	device: str = "cuda"
+	device: str = "cuda" # target device
 	mode: str = "training" # "inferencing"
 	experimental: bool = False # Debug flag, unused now
 
@@ -695,13 +696,13 @@ class Config(BaseConfig):
 	bitsandbytes: dict | list | None = None # deprecated
 	optimizations: Optimizations = field(default_factory=lambda: Optimizations)
 
-	tokenizer: str | None = None
-	tokenizer_path: str = "./tokenizer.json"
+	tokenizer: str | None = None # tokenizer class
+	tokenizer_path: str = "./tokenizer.json" # tokenizer path
 
-	sample_rate: int = 24_000
+	sample_rate: int = 24_000 # sample rate the model expects
 	variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
 
-	audio_backend: str = "vocos"
+	audio_backend: str = "vocos" # audio backend to use: "encodec" | "vocos" | "dac"
 
 	@property
 	def model(self):
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 6d4be71..9aaf635 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -178,11 +178,19 @@ def load_engines(training=True):
 		for k in erase:
 			del state[k]
 
-		# resize embeddings
-		if "text_emb.weight" in state:
-			state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
-		if "rvq_l_emb.weight" in state:
-			state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
+		# resize modules if I'm doing experiments and can't be assed to manually trim things
+		if cfg.trainer.resize_modules:
+			uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
+			keys = [
+				("text_emb.weight", model.config.text_tokens ),
+				("rvq_l_emb.weight", model.config.resp_levels ),
+				("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
+				("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
+				("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
+				("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
+			]
+			for k, tokens in keys:
+				state[k] = ml.resize_weight( state[k], tokens )
 
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
@@ -227,4 +235,4 @@ def load_engines(training=True):
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
-	return engines
\ No newline at end of file
+	return engines
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 8d6187c..e99d0bf 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -250,7 +250,13 @@ class AudioClassifier(nn.Module):
 		xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
 
 		# pad if needed
-		xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
+		max_size = max([ x.shape[-1] for x in xi ])
+		xi = [
+			#x if l == 0 else
+			x if x.shape[-1] == max_size else
+			torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
+			for x, l in zip(xi, levels)
+		]
 		return torch.stack( xi )
 
 class Metrics(nn.Module):
@@ -1060,7 +1066,7 @@ class Base(nn.Module):
 		# shamelessly grabbed from modeling_llama.py
 		ids = mask.long().cumsum(-1) - 1
 		ids.masked_fill_( mask == 0, 1 )
-
+		# there's a better way
 		if not self.unified_position_ids:
 			x_list = []
 
@@ -1074,7 +1080,7 @@ class Base(nn.Module):
 			if not isinstance(input, torch.Tensor):
 				return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
 
-			return input.shape[0] + (0 if name == "resp" else 1)
+			return input.shape[0] + (0 if name in ["resp", "len"] else 1)
 
 		for batch_index, batch_input in enumerate(inputs):
 			batch = torch.cat( [
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index c231c20..67d9d75 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -97,6 +97,8 @@ class NAR(Base):
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 
+		training: bool | None = None,
+
 		max_steps: int = 1000,
 		max_levels: int = 0,
 		max_resp_context: int = -1,
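
Note: the `ml.resize_weight` helper that the new `resize_modules` path relies on is not part of this patch. As a rough, hypothetical sketch of the operation it has to perform (truncating or growing the token dimension of an embedding or classifier weight so the checkpoint matches the current `model.config` sizes), assuming only dim 0 needs adjusting and that new rows can be zero-initialized:

# hypothetical sketch of what a resize_weight(weight, tokens) helper has to do;
# the real `ml.resize_weight` may initialize the new rows differently
import torch

def resize_weight(weight: torch.Tensor, tokens: int) -> torch.Tensor:
	# shrinking: keep only the first `tokens` rows (the token dimension is dim 0)
	if tokens <= weight.shape[0]:
		return weight[:tokens].clone()
	# growing: append zero-initialized rows for the newly added tokens;
	# this also covers 1-D biases such as classifiers.proj.0.bias, since
	# weight.shape[1:] is then empty
	padding = torch.zeros((tokens - weight.shape[0], *weight.shape[1:]), dtype=weight.dtype, device=weight.device)
	return torch.cat([weight, padding], dim=0)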
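
The reworked AudioClassifier comprehension pads each level's logits out to the widest level with `-inf` before stacking them. A standalone sketch of that intent (using `torch.nn.functional.pad` purely for brevity; the names here are illustrative, not the model's actual code):

# illustrative sketch: pad every level's logits to a common width with -inf
# so the per-level outputs can be stacked into a single tensor
import torch
import torch.nn.functional as F

def pad_and_stack(xi: list[torch.Tensor]) -> torch.Tensor:
	max_size = max(x.shape[-1] for x in xi)
	padded = [
		x if x.shape[-1] == max_size else
		F.pad(x, (0, max_size - x.shape[-1]), value=-float("inf"))
		for x in xi
	]
	return torch.stack(padded)

Padding with `-inf` keeps the extra positions from ever contributing to a softmax or being sampled for levels whose classifier emits fewer logits.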
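
The base.py change also folds the new "len" input into the per-segment length accounting used for non-unified position ids: "resp" and now "len" segments contribute only their own tokens, while every other segment is counted with one extra position (presumably the separator appended after it). A hedged sketch of that bookkeeping, with hypothetical helper names:

# hypothetical sketch of the per-segment position-id accounting described above
import torch

def segment_length(name: str, seq: torch.Tensor) -> int:
	# "resp" and "len" segments are not followed by a separator token,
	# so they only contribute their own length
	return seq.shape[0] + (0 if name in ["resp", "len"] else 1)

def non_unified_position_ids(segments: list[tuple[str, torch.Tensor]]) -> torch.Tensor:
	# every segment restarts its position ids at 0
	return torch.cat([torch.arange(segment_length(name, seq)) for name, seq in segments])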