fixes for the NAR-len model, and documentation some config options, and a better way to handle resizing modules on state_dict load

This commit is contained in:
mrq 2024-07-31 20:35:09 -05:00
parent 52d13b321f
commit 387358bc8a
4 changed files with 104 additions and 87 deletions

View File

@ -31,7 +31,7 @@ def set_seed(seed=None):
@dataclass()
class BaseConfig:
yaml_path: str | None = None
yaml_path: str | None = None # path passed in through --yaml
@property
def cfg_path(self):
@ -124,38 +124,38 @@ class BaseConfig:
@dataclass()
class Dataset:
training: list[Path] = field(default_factory=lambda: [])
validation: list[Path] = field(default_factory=lambda: [])
noise: list[Path] = field(default_factory=lambda: [])
training: list[Path] = field(default_factory=lambda: []) # paths to load into the training dataset
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset
temp: list[Path] = field(default_factory=lambda: [])
temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utternace path
speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utternace path
speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
hdf5_name: str = "data.h5"
use_hdf5: bool = False
hdf5_flag: str = "a"
use_metadata: bool = False
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
use_hdf5: bool = False # whether to load from an HDF5 dataset
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
use_metadata: bool = False # use genretaed metadata to aid in dataset loading
validate: bool = True
workers: int = 8
cache: bool = True
validate: bool = True # validate each utterance on wheter it can be included based on duration range caps
workers: int = 8 # number of dataloader workers to spawn
cache: bool = True # use diskcache to cache the dataset
phones_range: list[int] = field(default_factory=lambda: [4, 256])
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
min_utterances: int = 2
phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
min_utterances: int = 2 # minimum number of utterances a speaker can have
random_utterance: float = 1.0
max_prompts: int = 3
random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
prompt_duration: float | None = None # legacy
max_resps: int = 1
p_resp_append: float = 1.0
max_resps: int = 1 # number of samples to target for training
p_resp_append: float = 1.0 # probability to append another sample to the training target
sample_type: str = "path" # path | speaker
sample_order: str = "interleaved" # duration
@ -163,11 +163,11 @@ class Dataset:
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
sample_shuffle: bool = True #
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
noise_scale: float = 0.25 # scaling noise value
inject_noise_in_prom: bool = False
inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
_frames_per_second: int = 0 # allows setting your own hint
@ -369,7 +369,7 @@ class LoRA:
alpha: int = 128 # rank for the LoRA
training: bool = True #
embeddings: bool = False # train the embedding too
parametrize: bool = False #
parametrize: bool = False # whether to use the parameterized pathway for LoRAs or not
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@property
@ -385,42 +385,42 @@ class LoRA:
@dataclass()
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
gradient_clipping: int | float = 100
batch_size: int = 8 # number of samples per training batch
gradient_accumulation_steps: int = 32 # number of steps to accumulate gradients before updating
gradient_clipping: int | float = 10 # largest size a gradient norm can be
optimizer: str = "Adamw" # should be 'Prodigyopt" now
optimizer: str = "Adamw" # optimizer to use, should be 'Prodigyopt" now
optimizer_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
learning_rate: float = 3.25e-4 # should be 1.0 for ProdigyOpt
warmup_steps: int = 0
warmup_steps: int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
scheduler: str = ""
scheduler: str = "" # scheduler to use, currently don't ever use one so this doesn't really matter
scheduler_type: str = "" # deprecated
scheduler_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
autotune: bool = False
autotune: bool = False # to do deepspeed's autotuning
autotune_params: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
torch_optimizer: bool = False
torch_scheduler: bool = False
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
@dataclass()
class Evaluation:
batch_size: int = 64
frequency: int = 250
size: int = 64
steps: int = 500
ar_temperature: float = 1.0
nar_temperature: float = 0.0
nar_levels: int = 0
batch_size: int = 64 # number of samples per batch during eval / val
frequency: int = 250 # do eval / val every X iterations
size: int = 64 # number of samples to generate during eval / val
load_disabled_engines: bool = True
steps: int = 500
ar_temperature: float = 1.0 # AR temp for inferencing
nar_temperature: float = 0.0 # NAR temp for inferencing
nar_levels: int = 0 # maximum NAR levels to use for inferencing
load_disabled_engines: bool = True # see the other load_disabled_engines
@dataclass()
class DeepSpeed:
zero_optimization_level: int = 0
zero_optimization_level: int = 0 # doesn't seem to work
use_compression_training: bool = False # cope
compression_bits: int = 8 # cope
inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead
@ -576,45 +576,46 @@ class DeepSpeed:
@dataclass()
class Trainer:
iterations: int = 100_000
iterations: int = 1_000_000 # maximum iterations to train
save_tag: str = "step"
load_tag: str | None = None
save_tag: str = "step" # name to save checkpoints under, "step" will save as current step count
load_tag: str | None = None # tag to load checkpoint from; if None: will check against contents of `./ckpt/{model-name}/latest` for the checkpoint name
save_on_oom: bool = True
save_on_quit: bool = True
save_on_oom: bool = True # save if an OOM error is raised
save_on_quit: bool = True # save when quitting training
export_on_save: bool = False
export_on_quit: bool = False
export_on_save: bool = False # export weights to local `fp32.pth` state_dict on saving a checkpoint
export_on_quit: bool = False # export weights to local `fp32.pth` state_dict on quitting training
save_frequency: int = 100
save_frequency: int = 100 # frequency to save every X iterations
keep_last_checkpoints: int = 0
keep_last_checkpoints: int = 0 # number of checkpoints to keep, prunes oldest ones
load_state_dict: bool = False
load_states: bool = True
strict_loading: bool = False
load_module_only: bool = False
restart_step_count: bool = False
load_state_dict: bool = False # loads `fp32.pth` state_dict, will automatically be done if a checkpoint is not found but `fp32.pth` exists
load_states: bool = True #
strict_loading: bool = False # sets strict_loading=True when loading the state dict
load_module_only: bool = False #
restart_step_count: bool = False # clears the training stats when loading a checkpoint
resize_modules: bool = False # automatically resizes
activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
aggressive_optimizations: bool = False
check_for_oom: bool = True
gc_mode: str | None = None
load_disabled_engines: bool = False
aggressive_optimizations: bool = False # deprecated
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
weight_dtype: str = "float16"
weight_dtype: str = "float16" # dtype to have the model under
amp: bool = False
ddp: bool = False
amp: bool = False # automatic mixed precision
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
load_webui: bool = False
no_logger: bool = False
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
backend: str = "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@cached_property
def dtype(self):
@ -638,9 +639,9 @@ class Trainer:
@dataclass()
class Inference:
backend: str = "local"
weight_dtype: str = "float32"
amp: bool = False
backend: str = "local" # backend to use when inferencing
weight_dtype: str = "float32" # dtype to load the model under
amp: bool = False # automatic mixed precision during inferencing
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
@ -681,7 +682,7 @@ class Optimizations:
@dataclass()
class Config(BaseConfig):
device: str = "cuda"
device: str = "cuda" # target device
mode: str = "training" # "inferencing"
experimental: bool = False # Debug flag, unused now
@ -695,13 +696,13 @@ class Config(BaseConfig):
bitsandbytes: dict | list | None = None # deprecated
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
tokenizer: str | None = None
tokenizer_path: str = "./tokenizer.json"
tokenizer: str | None = None # tokenizer class
tokenizer_path: str = "./tokenizer.json" # tokenizer path
sample_rate: int = 24_000
sample_rate: int = 24_000 # sample rate the model expects
variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss
audio_backend: str = "vocos"
audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
@property
def model(self):

View File

@ -178,11 +178,19 @@ def load_engines(training=True):
for k in erase:
del state[k]
# resize embeddings
if "text_emb.weight" in state:
state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
if "rvq_l_emb.weight" in state:
state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
# resize modules if I'm doing experiments and can't be assed to manually trim things
if cfg.trainer.resize_modules:
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
keys = [
("text_emb.weight", model.config.text_tokens ),
("rvq_l_emb.weight", model.config.resp_levels ),
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
]
for k, tokens in keys:
state[k] = ml.resize_weight( state[k], tokens )
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
@ -227,4 +235,4 @@ def load_engines(training=True):
for name, engine in engines.items():
engine.freeze(freeze_all=False)
return engines
return engines

View File

@ -250,7 +250,13 @@ class AudioClassifier(nn.Module):
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
# pad if needed
xi = [ x if l == 0 else torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ], dim=-1 ) for x, l in zip(xi, levels) ]
max_size = max([ x.shape[-1] for x in xi ])
xi = [
#x if l == 0 else
x if x.shape[-1] == max_size else
torch.cat( [ x, torch.Tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
for x, l in zip(xi, levels)
]
return torch.stack( xi )
class Metrics(nn.Module):
@ -1060,7 +1066,7 @@ class Base(nn.Module):
# shamelessly grabbed from modeling_llama.py
ids = mask.long().cumsum(-1) - 1
ids.masked_fill_( mask == 0, 1 )
# there's a better way
if not self.unified_position_ids:
x_list = []
@ -1074,7 +1080,7 @@ class Base(nn.Module):
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
return input.shape[0] + (0 if name == "resp" else 1)
return input.shape[0] + (0 if name in ["resp", "len"] else 1)
for batch_index, batch_input in enumerate(inputs):
batch = torch.cat( [

View File

@ -97,6 +97,8 @@ class NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | None = None,
max_steps: int = 1000,
max_levels: int = 0,
max_resp_context: int = -1,