fixes for the NAR-len model, and documentation some config options, and a better way to handle resizing modules on state_dict load
This commit is contained in:
parent
52d13b321f
commit
387358bc8a
153
vall_e/config.py
153
vall_e/config.py
|
@ -31,7 +31,7 @@ def set_seed(seed=None):
|
|||
|
||||
@dataclass()
|
||||
class BaseConfig:
|
||||
yaml_path: str | None = None
|
||||
yaml_path: str | None = None # path passed in through --yaml
|
||||
|
||||
@property
|
||||
def cfg_path(self):
|
||||
|
@ -124,38 +124,38 @@ class BaseConfig:
|
|||
|
||||
@dataclass()
|
||||
class Dataset:
|
||||
training: list[Path] = field(default_factory=lambda: [])
|
||||
validation: list[Path] = field(default_factory=lambda: [])
|
||||
noise: list[Path] = field(default_factory=lambda: [])
|
||||
training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
|
||||
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
|
||||
noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset
|
||||
|
||||
temp: list[Path] = field(default_factory=lambda: [])
|
||||
temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out
|
||||
|
||||
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||
speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"
|
||||
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utternace path
|
||||
speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utternace path
|
||||
|
||||
speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
|
||||
|
||||
hdf5_name: str = "data.h5"
|
||||
use_hdf5: bool = False
|
||||
hdf5_flag: str = "a"
|
||||
use_metadata: bool = False
|
||||
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
|
||||
use_hdf5: bool = False # whether to load from an HDF5 dataset
|
||||
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
|
||||
use_metadata: bool = False # use genretaed metadata to aid in dataset loading
|
||||
|
||||
validate: bool = True
|
||||
workers: int = 8
|
||||
cache: bool = True
|
||||
validate: bool = True # validate each utterance on wheter it can be included based on duration range caps
|
||||
workers: int = 8 # number of dataloader workers to spawn
|
||||
cache: bool = True # use diskcache to cache the dataset
|
||||
|
||||
phones_range: list[int] = field(default_factory=lambda: [4, 256])
|
||||
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
|
||||
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
|
||||
min_utterances: int = 2
|
||||
phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
|
||||
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
|
||||
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
|
||||
min_utterances: int = 2 # minimum number of utterances a speaker can have
|
||||
|
||||
random_utterance: float = 1.0
|
||||
max_prompts: int = 3
|
||||
random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
|
||||
max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
|
||||
|
||||
prompt_duration: float | None = None # legacy
|
||||
|
||||
max_resps: int = 1
|
||||
p_resp_append: float = 1.0
|
||||
max_resps: int = 1 # number of samples to target for training
|
||||
p_resp_append: float = 1.0 # probability to append another sample to the training target
|
||||
|
||||
sample_type: str = "path" # path | speaker
|
||||
sample_order: str = "interleaved" # duration
|
||||
|
@ -163,11 +163,11 @@ class Dataset:
|
|||
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
|
||||
sample_shuffle: bool = True #
|
||||
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
|
||||
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
|
||||
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
|
||||
noise_scale: float = 0.25 # scaling noise value
|
||||
inject_noise_in_prom: bool = False
|
||||
inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
|
||||
|
||||
_frames_per_second: int = 0 # allows setting your own hint
|
||||
|
||||
|
@ -369,7 +369,7 @@ class LoRA:
|
|||
alpha: int = 128 # rank for the LoRA
|
||||
training: bool = True #
|
||||
embeddings: bool = False # train the embedding too
|
||||
parametrize: bool = False #
|
||||
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
|
||||
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
|
||||
|
||||
@property
|
||||
|
@ -385,42 +385,42 @@ class LoRA:
|
|||
|
||||
@dataclass()
|
||||
class Hyperparameters:
|
||||
batch_size: int = 8
|
||||
gradient_accumulation_steps: int = 32
|
||||
gradient_clipping: int | float = 100
|
||||
batch_size: int = 8 # number of samples per training batch
|
||||
gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
|
||||
gradient_clipping: int | float = 10 # largest size a gradient norm can be
|
||||
|
||||
optimizer: str = "Adamw" # should be 'Prodigyopt" now
|
||||
optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
|
||||
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
|
||||
warmup_steps: int = 0
|
||||
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
|
||||
|
||||
scheduler: str = ""
|
||||
scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
|
||||
scheduler_type: str = "" # deprecated
|
||||
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
autotune: bool = False
|
||||
autotune: bool = False # to do deepspeed's autotuning
|
||||
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
|
||||
|
||||
torch_optimizer: bool = False
|
||||
torch_scheduler: bool = False
|
||||
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
|
||||
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
|
||||
|
||||
@dataclass()
|
||||
class Evaluation:
|
||||
batch_size: int = 64
|
||||
frequency: int = 250
|
||||
size: int = 64
|
||||
batch_size: int = 64 # number of samples per batch during eval / val
|
||||
frequency: int = 250 # do eval / val every X iterations
|
||||
size: int = 64 # number of samples to generate during eval / val
|
||||
|
||||
steps: int = 500
|
||||
ar_temperature: float = 1.0
|
||||
nar_temperature: float = 0.0
|
||||
nar_levels: int = 0
|
||||
ar_temperature: float = 1.0 # AR temp for inferencing
|
||||
nar_temperature: float = 0.0 # NAR temp for inferencing
|
||||
nar_levels: int = 0 # maximum NAR levels to use for inferencing
|
||||
|
||||
load_disabled_engines: bool = True
|
||||
load_disabled_engines: bool = True # see the other load_disabled_engines
|
||||
|
||||
@dataclass()
|
||||
class DeepSpeed:
|
||||
zero_optimization_level: int = 0
|
||||
zero_optimization_level: int = 0 # doesn't seem to work
|
||||
use_compression_training: bool = False # cope
|
||||
compression_bits: int = 8 # cope
|
||||
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
|
||||
|
@ -576,45 +576,46 @@ class DeepSpeed:
|
|||
|
||||
@dataclass()
|
||||
class Trainer:
|
||||
iterations: int = 100_000
|
||||
iterations: int = 1_000_000 # maximum iterations to train
|
||||
|
||||
save_tag: str = "step"
|
||||
load_tag: str | None = None
|
||||
save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
|
||||
load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
|
||||
|
||||
save_on_oom: bool = True
|
||||
save_on_quit: bool = True
|
||||
save_on_oom: bool = True # save if an OOM error is raised
|
||||
save_on_quit: bool = True # save when quitting training
|
||||
|
||||
export_on_save: bool = False
|
||||
export_on_quit: bool = False
|
||||
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
|
||||
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
|
||||
|
||||
save_frequency: int = 100
|
||||
save_frequency: int = 100 # frequency to save every X iterations
|
||||
|
||||
keep_last_checkpoints: int = 0
|
||||
keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
|
||||
|
||||
load_state_dict: bool = False
|
||||
load_states: bool = True
|
||||
strict_loading: bool = False
|
||||
load_module_only: bool = False
|
||||
restart_step_count: bool = False
|
||||
load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
|
||||
load_states: bool = True #
|
||||
strict_loading: bool = False # sets strict_loading=True when loading the state dict
|
||||
load_module_only: bool = False #
|
||||
restart_step_count: bool = False # clears the training stats when loading a checkpoint
|
||||
resize_modules: bool = False # automatically resizes
|
||||
|
||||
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
|
||||
gradient_checkpointing: bool = True
|
||||
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
|
||||
|
||||
aggressive_optimizations: bool = False
|
||||
check_for_oom: bool = True
|
||||
gc_mode: str | None = None
|
||||
load_disabled_engines: bool = False
|
||||
aggressive_optimizations: bool = False # deprecated
|
||||
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
|
||||
gc_mode: str | None = None # deprecated, but marks when to do GC
|
||||
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
|
||||
|
||||
weight_dtype: str = "float16"
|
||||
weight_dtype: str = "float16" # dtype to have the model under
|
||||
|
||||
amp: bool = False
|
||||
ddp: bool = False
|
||||
amp: bool = False # automatic mixed precision
|
||||
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
|
||||
|
||||
load_webui: bool = False
|
||||
no_logger: bool = False
|
||||
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
|
||||
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
|
||||
|
||||
backend: str = "local"
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
|
||||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
|
@ -638,9 +639,9 @@ class Trainer:
|
|||
|
||||
@dataclass()
|
||||
class Inference:
|
||||
backend: str = "local"
|
||||
weight_dtype: str = "float32"
|
||||
amp: bool = False
|
||||
backend: str = "local" # backend to use when inferencing
|
||||
weight_dtype: str = "float32" # dtype to load the model under
|
||||
amp: bool = False # automatic mixed precision during inferencing
|
||||
|
||||
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
|
||||
|
||||
|
@ -681,7 +682,7 @@ class Optimizations:
|
|||
|
||||
@dataclass()
|
||||
class Config(BaseConfig):
|
||||
device: str = "cuda"
|
||||
device: str = "cuda" # target device
|
||||
mode: str = "training" # "inferencing"
|
||||
experimental: bool = False # Debug flag, unused now
|
||||
|
||||
|
@ -695,13 +696,13 @@ class Config(BaseConfig):
|
|||
bitsandbytes: dict | list | None = None # deprecated
|
||||
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
|
||||
|
||||
tokenizer: str | None = None
|
||||
tokenizer_path: str = "./tokenizer.json"
|
||||
tokenizer: str | None = None # tokenizer class
|
||||
tokenizer_path: str = "./tokenizer.json" # tokenizer path
|
||||
|
||||
sample_rate: int = 24_000
|
||||
sample_rate: int = 24_000 # sample rate the model expects
|
||||
variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
|
||||
|
||||
audio_backend: str = "vocos"
|
||||
audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
|
|
|
@ -178,11 +178,19 @@ def load_engines(training=True):
|
|||
for k in erase:
|
||||
del state[k]
|
||||
|
||||
# resize embeddings
|
||||
if "text_emb.weight" in state:
|
||||
state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
|
||||
if "rvq_l_emb.weight" in state:
|
||||
state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
|
||||
# resize modules if I'm doing experiments and can't be assed to manually trim things
|
||||
if cfg.trainer.resize_modules:
|
||||
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
|
||||
keys = [
|
||||
("text_emb.weight", model.config.text_tokens ),
|
||||
("rvq_l_emb.weight", model.config.resp_levels ),
|
||||
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
|
||||
("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
|
||||
]
|
||||
for k, tokens in keys:
|
||||
state[k] = ml.resize_weight( state[k], tokens )
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
|
|
|
@ -250,7 +250,13 @@ class AudioClassifier(nn.Module):
|
|||
|
||||
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
|
||||
# pad if needed
|
||||
xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
|
||||
max_size = max([ x.shape[-1] for x in xi ])
|
||||
xi = [
|
||||
#x if l == 0 else
|
||||
x if x.shape[-1] == max_size else
|
||||
torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||
for x, l in zip(xi, levels)
|
||||
]
|
||||
return torch.stack( xi )
|
||||
|
||||
class Metrics(nn.Module):
|
||||
|
@ -1074,7 +1080,7 @@ class Base(nn.Module):
|
|||
if not isinstance(input, torch.Tensor):
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
|
||||
|
||||
return input.shape[0] + (0 if name == "resp" else 1)
|
||||
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
|
||||
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
batch = torch.cat( [
|
||||
|
|
|
@ -97,6 +97,8 @@ class NAR(Base):
|
|||
tone_list: list[Tensor] | None = None,
|
||||
len_list: list[Tensor] | None = None,
|
||||
|
||||
training: bool | None = None,
|
||||
|
||||
max_steps: int = 1000,
|
||||
max_levels: int = 0,
|
||||
max_resp_context: int = -1,
|
||||
|
|
Loading…
Reference in New Issue
Block a user