fixes for the NAR-len model, documentation for some config options, and a better way to handle resizing modules on state_dict load
parent 52d13b321f
commit 387358bc8a

vall_e/config.py (153 changed lines)
@@ -31,7 +31,7 @@ def set_seed(seed=None):
 
 @dataclass()
 class BaseConfig:
-    yaml_path: str | None = None
+    yaml_path: str | None = None # path passed in through --yaml
 
     @property
     def cfg_path(self):
@@ -124,38 +124,38 @@ class BaseConfig:
 
 @dataclass()
 class Dataset:
-    training: list[Path] = field(default_factory=lambda: [])
-    validation: list[Path] = field(default_factory=lambda: [])
-    noise: list[Path] = field(default_factory=lambda: [])
+    training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
+    validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
+    noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset
 
-    temp: list[Path] = field(default_factory=lambda: [])
+    temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out
 
-    speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
-    speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"
+    speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utterance path
+    speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utterance path
 
     speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
 
-    hdf5_name: str = "data.h5"
-    use_hdf5: bool = False
-    hdf5_flag: str = "a"
-    use_metadata: bool = False
+    hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
+    use_hdf5: bool = False # whether to load from an HDF5 dataset
+    hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
+    use_metadata: bool = False # use generated metadata to aid in dataset loading
 
-    validate: bool = True
-    workers: int = 8
-    cache: bool = True
+    validate: bool = True # validate each utterance on whether it can be included based on duration range caps
+    workers: int = 8 # number of dataloader workers to spawn
+    cache: bool = True # use diskcache to cache the dataset
 
-    phones_range: list[int] = field(default_factory=lambda: [4, 256])
-    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
-    prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
-    min_utterances: int = 2
+    phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
+    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
+    prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
+    min_utterances: int = 2 # minimum number of utterances a speaker can have
 
-    random_utterance: float = 1.0
-    max_prompts: int = 3
+    random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
+    max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
 
     prompt_duration: float | None = None # legacy
 
-    max_resps: int = 1
-    p_resp_append: float = 1.0
+    max_resps: int = 1 # number of samples to target for training
+    p_resp_append: float = 1.0 # probability to append another sample to the training target
 
     sample_type: str = "path" # path | speaker
     sample_order: str = "interleaved" # duration
@@ -163,11 +163,11 @@ class Dataset:
     # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
     sample_shuffle: bool = True #
 
-    tasks_list: list[str] = field(default_factory=lambda: ["tts"])
+    tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
     reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
     reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
     noise_scale: float = 0.25 # scaling noise value
-    inject_noise_in_prom: bool = False
+    inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
 
     _frames_per_second: int = 0 # allows setting your own hint
 
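A quick illustration of how the two getter strings above are meant to be used: they are eval'd into lambdas and applied to an utterance's path. This is a minimal sketch; the directory layout is hypothetical, only the two lambda strings come from the config.

from pathlib import Path

# hypothetical utterance path layout: .../{speaker group}/{speaker}/{utterance}
path = Path("training/LibriTTS/103/utterance_0001.enc")

name_getter  = eval("lambda p: f'{p.parts[-3]}_{p.parts[-2]}'")  # speaker_name_getter
group_getter = eval("lambda p: f'{p.parts[-3]}'")                # speaker_group_getter

print(name_getter(path))   # LibriTTS_103
print(group_getter(path))  # LibriTTS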
@@ -369,7 +369,7 @@ class LoRA:
     alpha: int = 128 # rank for the LoRA
     training: bool = True #
     embeddings: bool = False # train the embedding too
-    parametrize: bool = False #
+    parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
     rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
 
     @property
@@ -385,42 +385,42 @@ class LoRA:
 
 @dataclass()
 class Hyperparameters:
-    batch_size: int = 8
-    gradient_accumulation_steps: int = 32
-    gradient_clipping: int | float = 100
+    batch_size: int = 8 # number of samples per training batch
+    gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
+    gradient_clipping: int | float = 10 # largest size a gradient norm can be
 
-    optimizer: str = "Adamw" # should be 'Prodigyopt" now
+    optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
     optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
     learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
-    warmup_steps: int = 0
+    warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
 
-    scheduler: str = ""
+    scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
     scheduler_type: str = "" # deprecated
     scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
-    autotune: bool = False
+    autotune: bool = False # to do deepspeed's autotuning
     autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
 
-    torch_optimizer: bool = False
-    torch_scheduler: bool = False
+    torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
+    torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
 
 @dataclass()
 class Evaluation:
-    batch_size: int = 64
-    frequency: int = 250
-    size: int = 64
+    batch_size: int = 64 # number of samples per batch during eval / val
+    frequency: int = 250 # do eval / val every X iterations
+    size: int = 64 # number of samples to generate during eval / val
 
     steps: int = 500
-    ar_temperature: float = 1.0
-    nar_temperature: float = 0.0
-    nar_levels: int = 0
+    ar_temperature: float = 1.0 # AR temp for inferencing
+    nar_temperature: float = 0.0 # NAR temp for inferencing
+    nar_levels: int = 0 # maximum NAR levels to use for inferencing
 
-    load_disabled_engines: bool = True
+    load_disabled_engines: bool = True # see the other load_disabled_engines
 
 @dataclass()
 class DeepSpeed:
-    zero_optimization_level: int = 0
+    zero_optimization_level: int = 0 # doesn't seem to work
     use_compression_training: bool = False # cope
     compression_bits: int = 8 # cope
     inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
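For reference, the usual reading of batch_size together with gradient_accumulation_steps, as a back-of-the-envelope sketch rather than anything taken from the repo:

# per-GPU samples consumed before each optimizer update,
# assuming the conventional meaning of these two fields
batch_size = 8
gradient_accumulation_steps = 32
print(batch_size * gradient_accumulation_steps)  # 256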
@@ -576,45 +576,46 @@ class DeepSpeed:
 
 @dataclass()
 class Trainer:
-    iterations: int = 100_000
+    iterations: int = 1_000_000 # maximum iterations to train
 
-    save_tag: str = "step"
-    load_tag: str | None = None
+    save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
+    load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
 
-    save_on_oom: bool = True
-    save_on_quit: bool = True
+    save_on_oom: bool = True # save if an OOM error is raised
+    save_on_quit: bool = True # save when quitting training
 
-    export_on_save: bool = False
-    export_on_quit: bool = False
+    export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
+    export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
 
-    save_frequency: int = 100
+    save_frequency: int = 100 # frequency to save every X iterations
 
-    keep_last_checkpoints: int = 0
+    keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
 
-    load_state_dict: bool = False
-    load_states: bool = True
-    strict_loading: bool = False
-    load_module_only: bool = False
-    restart_step_count: bool = False
+    load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
+    load_states: bool = True #
+    strict_loading: bool = False # sets strict_loading=True when loading the state dict
+    load_module_only: bool = False #
+    restart_step_count: bool = False # clears the training stats when loading a checkpoint
+    resize_modules: bool = False # automatically resizes
 
     activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
-    gradient_checkpointing: bool = True
+    gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
 
-    aggressive_optimizations: bool = False
-    check_for_oom: bool = True
-    gc_mode: str | None = None
-    load_disabled_engines: bool = False
+    aggressive_optimizations: bool = False # deprecated
+    check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
+    gc_mode: str | None = None # deprecated, but marks when to do GC
+    load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
 
-    weight_dtype: str = "float16"
+    weight_dtype: str = "float16" # dtype to have the model under
 
-    amp: bool = False
-    ddp: bool = False
+    amp: bool = False # automatic mixed precision
+    ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
 
-    load_webui: bool = False
-    no_logger: bool = False
+    load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
+    no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
 
-    backend: str = "local"
-    deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
+    backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
+    deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
 
     @cached_property
     def dtype(self):
@@ -638,9 +639,9 @@ class Trainer:
 
 @dataclass()
 class Inference:
-    backend: str = "local"
-    weight_dtype: str = "float32"
-    amp: bool = False
+    backend: str = "local" # backend to use when inferencing
+    weight_dtype: str = "float32" # dtype to load the model under
+    amp: bool = False # automatic mixed precision during inferencing
 
     normalize: bool = False # do NOT enable this unless you know exactly what you're doing
 
@@ -681,7 +682,7 @@ class Optimizations:
 
 @dataclass()
 class Config(BaseConfig):
-    device: str = "cuda"
+    device: str = "cuda" # target device
     mode: str = "training" # "inferencing"
     experimental: bool = False # Debug flag, unused now
 
@@ -695,13 +696,13 @@ class Config(BaseConfig):
     bitsandbytes: dict | list | None = None # deprecated
     optimizations: Optimizations = field(default_factory=lambda: Optimizations)
 
-    tokenizer: str | None = None
-    tokenizer_path: str = "./tokenizer.json"
+    tokenizer: str | None = None # tokenizer class
+    tokenizer_path: str = "./tokenizer.json" # tokenizer path
 
-    sample_rate: int = 24_000
+    sample_rate: int = 24_000 # sample rate the model expects
     variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
 
-    audio_backend: str = "vocos"
+    audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
 
     @property
     def model(self):
@@ -178,11 +178,19 @@ def load_engines(training=True):
            for k in erase:
                del state[k]
 
-           # resize embeddings
-           if "text_emb.weight" in state:
-               state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
-           if "rvq_l_emb.weight" in state:
-               state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
+           # resize modules if I'm doing experiments and can't be assed to manually trim things
+           if cfg.trainer.resize_modules:
+               uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
+               keys = [
+                   ("text_emb.weight", model.config.text_tokens ),
+                   ("rvq_l_emb.weight", model.config.resp_levels ),
+                   ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
+                   ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
+                   ("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
+                   ("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
+               ]
+               for k, tokens in keys:
+                   state[k] = ml.resize_weight( state[k], tokens )
 
            model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
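The new branch leans on ml.resize_weight to grow or shrink every vocabulary-sized tensor (embeddings and classifier head) to the configured token count, counting the stop token only when the model lacks the "len" capability. A minimal sketch of what such a resize helper might look like, written as an assumption for illustration; the actual helper in the repo may differ:

import torch

def resize_weight(weight: torch.Tensor, tokens: int) -> torch.Tensor:
    # shrink: drop trailing rows; grow: append newly-initialized rows
    # works for 2-D embedding/projection weights and 1-D biases alike
    if tokens <= weight.shape[0]:
        return weight[:tokens].clone()
    extra = torch.empty(tokens - weight.shape[0], *weight.shape[1:], dtype=weight.dtype)
    torch.nn.init.normal_(extra, mean=0.0, std=0.02)
    return torch.cat([weight, extra], dim=0)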
@@ -250,7 +250,13 @@ class AudioClassifier(nn.Module):
 
        xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
        # pad if needed
-       xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
+       max_size = max([ x.shape[-1] for x in xi ])
+       xi = [
+           #x if l == 0 else
+           x if x.shape[-1] == max_size else
+           torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
+           for x, l in zip(xi, levels)
+       ]
        return torch.stack( xi )
 
 class Metrics(nn.Module):
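The rewritten padding no longer assumes only level 0 carries the extra (stop) token: every level's logits get padded out to the widest vocabulary before stacking. A standalone sketch of the same idea, with made-up shapes and torch.full used for the pad instead of the repo's exact construction:

import torch

# toy per-level logits with different vocab sizes, e.g. one level has a stop token and one doesn't
xi = [ torch.randn(5, 1025), torch.randn(5, 1024) ]

max_size = max( x.shape[-1] for x in xi )
xi = [
    x if x.shape[-1] == max_size else
    # -inf in the padded columns keeps them out of any softmax over the stacked logits
    torch.cat([ x, torch.full((x.shape[0], max_size - x.shape[-1]), -float("inf")) ], dim=-1)
    for x in xi
]
print(torch.stack(xi).shape)  # torch.Size([2, 5, 1025])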
@@ -1074,7 +1080,7 @@ class Base(nn.Module):
            if not isinstance(input, torch.Tensor):
                return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
 
-           return input.shape[0] + (0 if name == "resp" else 1)
+           return input.shape[0] + (0 if name in ["resp", "len"] else 1)
 
        for batch_index, batch_input in enumerate(inputs):
            batch = torch.cat( [
@@ -97,6 +97,8 @@ class NAR(Base):
        tone_list: list[Tensor] | None = None,
        len_list: list[Tensor] | None = None,
 
+       training: bool | None = None,
+
        max_steps: int = 1000,
        max_levels: int = 0,
        max_resp_context: int = -1,