2023-08-02 21:53:35 +00:00
import copy
import diskcache
import h5py
import json
import os
import subprocess
import sys
import time
2024-06-08 01:23:53 +00:00
import argparse
import yaml
2024-06-25 18:41:29 +00:00
import random
2024-08-29 18:27:16 +00:00
import logging
2024-09-19 02:34:43 +00:00
import itertools
2023-08-02 21:53:35 +00:00
2023-08-14 03:07:45 +00:00
import torch
2024-06-25 18:41:29 +00:00
import numpy as np
2023-08-14 03:07:45 +00:00
2024-04-17 02:04:48 +00:00
from dataclasses import asdict , dataclass , field
2023-08-02 21:53:35 +00:00
from functools import cached_property
from pathlib import Path
2023-08-14 03:56:28 +00:00
from . utils . distributed import world_size
2024-10-26 03:15:15 +00:00
from . utils . io import torch_load
2024-11-11 22:32:08 +00:00
from . utils import set_seed , prune_missing , md5_hash
2024-06-25 18:41:29 +00:00
2023-08-02 21:53:35 +00:00
@dataclass ( )
2024-06-08 01:23:53 +00:00
class BaseConfig :
2024-08-01 01:35:09 +00:00
yaml_path : str | None = None # path passed in through --yaml
2023-08-02 21:53:35 +00:00
@property
2024-06-09 16:22:52 +00:00
def cfg_path ( self ) :
2024-10-26 03:15:15 +00:00
if self . yaml_path :
return Path ( self . yaml_path . parent )
return Path ( __file__ ) . parent . parent / " data "
2024-06-09 16:22:52 +00:00
@property
def rel_path ( self ) :
2023-08-02 21:53:35 +00:00
return Path ( self . cfg_path )
2023-08-27 00:53:23 +00:00
@property
def cache_dir ( self ) :
2024-06-09 16:22:52 +00:00
return self . rel_path / " .cache "
2023-08-27 00:53:23 +00:00
2024-04-29 03:28:29 +00:00
@property
def data_dir ( self ) :
2024-06-09 16:22:52 +00:00
return self . rel_path / " data "
2024-04-29 03:28:29 +00:00
@property
def metadata_dir ( self ) :
2024-06-09 16:22:52 +00:00
return self . rel_path / " metadata "
2024-04-29 03:28:29 +00:00
2023-08-02 21:53:35 +00:00
@property
def ckpt_dir ( self ) :
2024-06-09 16:22:52 +00:00
return self . rel_path / " ckpt "
2023-08-02 21:53:35 +00:00
@property
def log_dir ( self ) :
2024-06-09 16:22:52 +00:00
return self . rel_path / " logs " / str ( self . start_time )
2023-08-02 21:53:35 +00:00
@cached_property
def start_time ( self ) :
return int ( time . time ( ) )
@cached_property
def git_commit ( self ) :
try :
cmd = " git rev-parse HEAD "
return subprocess . check_output ( cmd . split ( ) ) . decode ( " utf8 " ) . strip ( )
except :
return " "
@cached_property
def git_status ( self ) :
try :
cmd = " git status "
return subprocess . check_output ( cmd . split ( ) ) . decode ( " utf8 " ) . strip ( )
except :
return " "
def dumps ( self ) :
data = { k : getattr ( self , k ) for k in dir ( self ) if not k . startswith ( " __ " ) }
data = { k : v for k , v in data . items ( ) if not callable ( v ) }
return json . dumps ( data , indent = 2 , default = str )
def dump ( self , path = None ) :
if path is None :
path = self . log_dir / " cfg.json "
path . parent . mkdir ( parents = True , exist_ok = True )
with open ( path , " w " ) as f :
f . write ( self . dumps ( ) )
2024-10-17 22:06:48 +00:00
# ick
@classmethod
def prune_missing ( cls , yaml ) :
default = cls ( * * { } )
default . format ( )
yaml , missing = prune_missing ( source = default , dest = yaml )
if missing :
_logger . warning ( f ' Missing keys in YAML: { missing } ' )
return yaml
2023-08-02 21:53:35 +00:00
@classmethod
def from_yaml ( cls , yaml_path ) :
2024-06-09 22:11:38 +00:00
state = { }
state = yaml . safe_load ( open ( yaml_path , " r " , encoding = " utf-8 " ) )
state . setdefault ( " yaml_path " , yaml_path )
2024-10-17 22:06:48 +00:00
state = cls . prune_missing ( state )
2024-06-09 22:11:38 +00:00
return cls ( * * state )
2023-08-02 21:53:35 +00:00
2024-10-26 03:15:15 +00:00
@classmethod
2024-10-26 05:13:10 +00:00
def from_model ( cls , model_path , lora_path = None ) :
2024-10-26 03:15:15 +00:00
if not model_path . exists ( ) :
raise Exception ( f ' Model path does not exist: { model_path } ' )
# load state dict and copy its stored model config
2024-11-04 00:31:28 +00:00
model_state_dict = [ torch_load ( model_path ) [ " config " ] | { " path " : model_path , " attention " : " auto " } ] if model_path and model_path . exists ( ) else [ ]
2024-10-26 05:13:10 +00:00
lora_state_dict = [ torch_load ( lora_path ) [ " config " ] | { " path " : lora_path } ] if lora_path and lora_path . exists ( ) else [ ]
2024-10-26 03:15:15 +00:00
2024-10-26 05:13:10 +00:00
state = { " models " : model_state_dict , " loras " : lora_state_dict , " trainer " : { " load_state_dict " : True } }
2024-10-26 03:15:15 +00:00
return cls ( * * state )
2023-08-02 21:53:35 +00:00
@classmethod
def from_cli ( cls , args = sys . argv ) :
2024-06-08 01:23:53 +00:00
# legacy support for yaml=`` format
for i , arg in enumerate ( args ) :
if arg . startswith ( " yaml " ) :
args [ i ] = f ' -- { arg } '
2023-08-02 21:53:35 +00:00
2024-10-18 18:19:36 +00:00
parser = argparse . ArgumentParser ( allow_abbrev = False , add_help = False )
2024-06-08 01:23:53 +00:00
parser . add_argument ( " --yaml " , type = Path , default = os . environ . get ( ' VALLE_YAML ' , None ) ) # os environ so it can be specified in a HuggingFace Space too
2024-10-26 03:15:15 +00:00
parser . add_argument ( " --model " , type = Path , default = os . environ . get ( ' VALLE_MODEL ' , None ) ) # os environ so it can be specified in a HuggingFace Space too
2024-10-26 05:13:10 +00:00
parser . add_argument ( " --lora " , type = Path , default = os . environ . get ( ' VALLE_LORA ' , None ) ) # os environ so it can be specified in a HuggingFace Space too
2024-06-08 01:23:53 +00:00
args , unknown = parser . parse_known_args ( args = args )
2023-08-02 21:53:35 +00:00
2024-10-26 03:15:15 +00:00
if args . model :
2024-10-26 05:13:10 +00:00
return cls . from_model ( args . model , args . lora )
2024-10-26 03:15:15 +00:00
2024-06-08 01:23:53 +00:00
if args . yaml :
2024-10-18 18:19:36 +00:00
return cls . from_yaml ( args . yaml )
2023-08-02 21:53:35 +00:00
2024-06-09 22:11:38 +00:00
return cls ( * * { } )
2023-08-02 21:53:35 +00:00
def __repr__ ( self ) :
return str ( self )
def __str__ ( self ) :
return self . dumps ( )
@dataclass()
class Dataset:
    """Dataset settings: source paths, loading back-end, filtering, sampling and prompt/response shaping."""

    training: list[Path] = field(default_factory=list)  # paths to load into the training dataset
    validation: list[Path] = field(default_factory=list)  # paths to load into the validation dataset
    noise: list[Path] = field(default_factory=list)  # paths to load into the noise dataset

    # to-do: replace these since I feel this can be a bottleneck
    # function eval'd to extract a speaker's name from an utterance path
    speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
    # function eval'd to extract a speaker's group from an utterance path
    speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'"
    # to-do: validate if I can ignore this since this is an artifact from when I only saved
    # phonemes and encoded audio, and no metadata
    # dict where keys are the language codes and values are the speaker groups
    speaker_languages: dict = field(default_factory=dict)

    use_hdf5: bool = False  # whether to load from an HDF5 dataset
    hdf5_name: str = "data.h5"  # file name to load the HDF5 dataset
    hdf5_flag: str = "a"  # flag to load the HDF5 file, automatically adjusted anyways

    use_metadata: bool = False  # use generated metadata to aid in dataset loading

    validate: bool = True  # validate each utterance on whether it can be included based on duration range caps
    workers: int = 8  # number of dataloader workers to spawn
    cache: bool = True  # use diskcache to cache the dataset

    min_utterances: int = 2  # minimum number of utterances a speaker can have
    max_utterances: int = 0  # max number of utterances a speaker can have (0 to disable)
    # the duration range an utterance can be to be included in the dataset
    duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0])

    sample_type: str = "path"  # path | speaker
    sample_order: str = "interleaved"  # duration
    sample_shuffle: bool = True  # shuffles the indices in the sampler
    # total number of seconds of utterances per batch, 0 to disable
    # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
    # for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed
    # (but it might be limited by batch size)
    sample_max_duration_batch: float = 0.0

    # the duration range the input prompts can be
    prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0])
    prompt_max_samples: int = 3  # maximum number of utterances that can be included in an input prompt for training
    # probability to use the target utterance as an input prompt rather than using a different utterance
    prompt_continuous_utterance_p: float = 0.0
    prompt_similar_p: float = 0.75  # odds of sampling for a similar prompt instead of a random prompt
    prompt_similar_top_k: int = 1  # top-k similar candidates to sample from
    prompt_similar_top_k_offset: int = 0  # offset from the top-k to sample from
    prompt_inject_noise: bool = False  # adds noise to the input prompt waveform to try and vary things

    resps_max_samples: int = 1  # number of samples to target for training
    resps_append_p: float = 1.0  # probability to append another sample to the training target
    resps_pad_silence_p: float = 0.0  # probability to pad resp with silence to fit within the next window

    tasks_list: list[str] = field(default_factory=lambda: ["tts"])  # list of tasks to train against
    reencode_on_concat: bool = False  # whether to concat audio by decode => concat => encode, or naively concat codes
    # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA
    # in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
    reencode_device: str = "cpu"
    noise_scale: float = 0.25  # scaling noise value
    retokenize_text: bool = False

    _frames_per_second: int = 0  # allows setting your own hint

    # Backs min_phones / max_phones, which otherwise raise AttributeError because
    # nothing else declares it. Appended last so the positional order of the
    # existing fields is untouched.
    # NOTE(review): the default is a permissive placeholder - confirm the intended range.
    phones_range: list[int] = field(default_factory=lambda: [4, 256])

    def hash_key(self, *args):
        """Return ``md5_hash`` over the HDF5 flag, the duration bounds and any extra ``args``."""
        return md5_hash([self.use_hdf5, self.min_duration, self.max_duration] + [*args])

    @cached_property
    def frames_per_second(self):
        """Return the codec frame rate, preferring the ``_frames_per_second`` hint when positive.

        Otherwise derived from the global ``cfg``: 87 (44.1kHz) or 50 (16kHz)
        for the "dac" backend, 75 in every other case.
        """
        if self._frames_per_second > 0:
            return self._frames_per_second

        # NOTE(review): both sample-rate checks are taken to sit under the
        # "dac" branch; the captured source lost its indentation - confirm.
        if cfg.audio_backend == "dac":
            if cfg.sample_rate == 44_100:
                return 87
            if cfg.sample_rate == 16_000:
                return 50

        # 24Khz Encodec / Vocos and incidentally DAC are all at 75Hz
        return 75

    @property
    def min_phones(self):
        """Return the lower bound of ``phones_range``."""
        return self.phones_range[0]

    @property
    def max_phones(self):
        """Return the upper bound of ``phones_range``."""
        return self.phones_range[1]

    @property
    def min_duration(self):
        """Return the lower bound of ``duration_range``."""
        return self.duration_range[0]

    @property
    def max_duration(self):
        """Return the upper bound of ``duration_range``."""
        return self.duration_range[1]
2024-07-19 20:33:31 +00:00
# collection of experimental variables that should not be tampered with unless you know what you're doing
2024-06-30 15:37:33 +00:00
@dataclass ( )
class ModelExperimentalSettings :
hf : bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
interleave : bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
2024-07-19 20:33:31 +00:00
split_classifiers : bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
2024-06-30 15:37:33 +00:00
audio_embedding_sums : bool = False # whether each pass uses the previous RVQ codes or only the current level
2024-09-08 13:30:30 +00:00
# a model trained not summing audio embeddings *can* have this enabled without any apparent issues
# a model trained to sum *cannot* have this disabled without any apparent issues, or at least the ar+nar-retnet-8 can't.
# in theory a model that is trained to sum embeddings can peform better due to "seeing" previous levles (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so
2024-06-30 16:00:12 +00:00
audio_embedding_mode : str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
2024-06-30 15:37:33 +00:00
kv_heads : int = 0 # MHA or GQA (for supported backends)
2024-10-17 22:06:48 +00:00
rvq_levels_p : str | list = " auto " # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
2024-07-19 20:33:31 +00:00
rvq_level_range : list = field ( default_factory = lambda : [ ] ) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
2024-07-17 00:52:41 +00:00
unified_position_ids : bool = True # False will generate position IDs partitioned for each section
2024-07-19 20:33:31 +00:00
tie_classifier_to_embedding : bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
2024-07-25 00:35:17 +00:00
# performs token dropout to compensate for errors
token_dropout_error : float = 0.0 # probability to nudge a token by ±1
token_dropout_rate : float = 0.0 # probability to randomly set a token to a special dropout value
token_dropout_rvq_levels : list = field ( default_factory = lambda : [ 1 , 8 ] ) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
2024-06-30 15:37:33 +00:00
2024-07-31 01:53:51 +00:00
causal_size : int = 1 # experimental setting to see if I can just do parallel decoding in chunks instead of one-at-a-time without resorting to exotic solutions
# VALL-E 2's approach of "combining token embeddings to group them" sounds terribad for a shared AR/NAR model
# however, introducing partial parallel decoding for the AR maybe maybe MAYBE might help try and unify the AR/NAR tasks better, MAYBE
2024-08-04 00:51:00 +00:00
# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
2024-09-08 13:30:30 +00:00
# RetNet's chunked inferencing might be a better place for this
2024-08-04 00:51:00 +00:00
2024-11-10 18:19:48 +00:00
masking_train_p : float = 0.0 # odds of training with masking
masking_train_rvq_levels : list = field ( default_factory = lambda : [ 0 , 0 ] ) # determines which levels to do mask training on
2024-11-15 04:17:47 +00:00
2024-11-20 00:51:17 +00:00
masking_ratio : str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
2024-11-16 21:49:06 +00:00
ignore_inputs_for_loss : bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
2024-10-31 18:24:48 +00:00
2024-11-20 00:51:17 +00:00
# classifier-free guidance training settings
2024-11-10 18:19:48 +00:00
cfg_cond_dropout_p : float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p : float = 0.0 # 0.0 # probability to drop out input audio prompt during training
cfg_prom_dropout_p : float = 0.0 # 0.3 # probability to drop out input audio prompt during training
2024-11-10 00:04:59 +00:00
2024-11-20 00:51:17 +00:00
# failed experiment
2024-10-31 18:24:48 +00:00
layerskip : bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
2024-11-01 22:06:07 +00:00
layerskip_r : int = 2 # number of layers to factor into early-exit loss calc
2024-10-31 18:24:48 +00:00
layerskip_p_max : float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
layerskip_e_scale : float = 0.2 # early-exit loss scalar value
2024-07-31 01:53:51 +00:00
2024-06-04 02:28:49 +00:00
# I really need to clean this up
2023-08-02 21:53:35 +00:00
@dataclass ( )
class Model :
2024-07-16 00:59:48 +00:00
name : str = " ar+nar " # vanity name for the model
2024-06-29 03:28:54 +00:00
version : int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
2023-10-02 21:52:42 +00:00
size : str | dict = " full " # preset string or explicitly defined dimensionality
2024-07-16 00:59:48 +00:00
resp_levels : int = 8 # RVQ-bin levels this model supports
2024-06-29 03:28:54 +00:00
tasks : int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
langs : int = 1 # defined languages (semi-unused)
tones : int = 1 # defined tones (unsued)
experts : int = 1 # for mixtral / retnet-ts
arch_type : str = " llama " # underling LM architecture used
training : bool = True # I really need to attend to this
2023-10-02 21:52:42 +00:00
frozen_params : list [ str ] = field ( default_factory = lambda : [ ] ) # frozen parameters that are not updated when training
2024-06-29 03:28:54 +00:00
attention : str = " auto " # for llama arch_types: attention used
2024-05-11 21:31:05 +00:00
dropout : float = 0.1 # adjustable dropout value
2024-10-26 05:13:10 +00:00
path : Path | None = None
2024-06-08 01:29:25 +00:00
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
2024-06-05 04:48:51 +00:00
loss_factors : dict = field ( default_factory = lambda : { } )
2024-07-19 20:33:31 +00:00
capabilities : list = field ( default_factory = lambda : [ " ar " , " nar " ] ) # + ["lang", "tone"] if you have your dataset labeled for such
2024-11-12 02:21:16 +00:00
kwargs : dict = field ( default_factory = lambda : { } )
2024-06-29 03:28:54 +00:00
2024-06-30 15:37:33 +00:00
experimental : dict | ModelExperimentalSettings | None = None # experimental settings
2023-08-02 21:53:35 +00:00
2024-04-16 00:54:32 +00:00
def get ( self , name = None ) :
return [ self ] if not name or self . name == name else [ ]
2024-05-19 16:23:56 +00:00
def loss_factor ( self , k ) :
2024-11-15 04:17:47 +00:00
return self . loss_factors . get ( k , 0.0 )
2024-04-16 00:54:32 +00:00
@property
def max_levels ( self ) :
2024-07-16 00:59:48 +00:00
# return RVQ level range
if self . experimental is not None and self . experimental . rvq_level_range :
return self . experimental . rvq_level_range [ - 1 ]
return self . resp_levels
2024-04-16 00:54:32 +00:00
2024-04-09 01:14:51 +00:00
@property
# required for fp8 as the lengths needs to be divisible by 8
def input_alignment ( self ) :
2024-05-03 01:08:59 +00:00
return 8 if cfg . optimizations . fp8 else 0
2024-04-09 01:14:51 +00:00
2023-08-02 21:53:35 +00:00
@property
def full_name ( self ) :
name = [ self . name ]
2024-04-13 17:43:35 +00:00
if isinstance ( self . size , dict ) :
if hasattr ( self . size , " label " ) and self . size [ ' label ' ] :
name . append ( f " { self . size [ ' label ' ] } " )
elif isinstance ( self . size , str ) and self . size != " full " :
2023-08-02 21:53:35 +00:00
name . append ( self . size )
2024-06-06 01:53:10 +00:00
if self . experts > 1 :
name . append ( f ' { self . experts } x ' + self . arch_type . replace ( " / " , " - " ) )
else :
name . append ( self . arch_type . replace ( " / " , " - " ) )
2023-08-02 21:53:35 +00:00
2024-05-03 01:08:59 +00:00
if cfg . optimizations . bitnet :
2024-03-01 16:32:35 +00:00
name . append ( " bitnet " )
2024-06-30 16:11:58 +00:00
name . append ( f ' { self . resp_levels } ' )
2023-08-02 21:53:35 +00:00
return " - " . join ( name )
@property
def tokens ( self ) :
2024-06-06 01:53:10 +00:00
return self . audio_tokens
2023-09-01 22:19:34 +00:00
2024-06-06 01:53:10 +00:00
@property
def audio_tokens ( self ) :
if isinstance ( self . size , dict ) and hasattr ( self . size , " audio_tokens " ) :
return self . size [ ' audio_tokens ' ]
2023-08-02 21:53:35 +00:00
return 1024
2024-06-06 01:53:10 +00:00
@property
def text_tokens ( self ) :
if isinstance ( self . size , dict ) and hasattr ( self . size , " text_tokens " ) :
return self . size [ ' text_tokens ' ]
return 256
2023-08-02 21:53:35 +00:00
@property
def dim ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " dim " ) :
return self . size [ ' dim ' ]
if isinstance ( self . size , float ) :
return math . floor ( 1024 * self . size )
2023-08-02 21:53:35 +00:00
if self . size == " quarter " :
return 256
if self . size == " half " :
return 512
2023-09-02 02:33:51 +00:00
return 1024
2023-08-02 21:53:35 +00:00
@property
def heads ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " heads " ) :
return self . size [ ' heads ' ]
if isinstance ( self . size , float ) :
return math . floor ( 16 * self . size )
2023-08-02 21:53:35 +00:00
if self . size == " quarter " :
return 4
if self . size == " half " :
return 8
2023-09-02 02:33:51 +00:00
return 16
2023-08-02 21:53:35 +00:00
@property
def layers ( self ) :
2023-09-01 22:19:34 +00:00
if isinstance ( self . size , dict ) and hasattr ( self . size , " layers " ) :
return self . size [ ' layers ' ]
2023-09-02 02:33:51 +00:00
if self . size == " double " :
return 24
2023-08-02 21:53:35 +00:00
return 12
2023-09-05 20:38:21 +00:00
@property
def activation_checkpointing ( self ) :
return cfg . trainer . activation_checkpointing
2024-06-04 02:28:49 +00:00
@property
def gradient_checkpointing ( self ) :
return cfg . trainer . gradient_checkpointing
2024-06-17 05:09:16 +00:00
2024-06-17 18:05:06 +00:00
@property
def lora_policy ( self ) :
include = [ " model " ] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
exclude = [ ]
if self . arch_type == " llama " :
include = [ " self_attn " , " mlp " ] # target only the attention + mlp
exclude = [ " self_attn.k_proj " ] # common literature says to ignore it
2024-07-16 23:23:13 +00:00
if self . arch_type == " retnet " :
include = [ " layers. " ] # target the core layers of the RetNet and ignore the auxiliary stuff
exclude = [ " retention.k_proj " ] # attention-based transformers ignore the K, so might as well ignore it for the retnet
2024-06-17 18:05:06 +00:00
return dict ( include = include , exclude = exclude )
2024-11-12 02:21:16 +00:00
# to-do: derive default arguments from here
@property
def get_kwargs ( self , type ) :
return self . kwargs
2024-08-29 18:27:16 +00:00
# should be renamed to Adapters
2024-06-17 05:09:16 +00:00
@dataclass ( )
class LoRA :
name : str = " lora " # vanity name
2024-06-17 18:05:06 +00:00
# to-do: find sane default values
2024-06-29 03:28:54 +00:00
rank : int = 128 # rank for the LoRA
alpha : int = 128 # rank for the LoRA
2024-06-17 05:09:16 +00:00
training : bool = True #
2024-07-23 00:36:07 +00:00
embeddings : bool = False # train the embedding too
2024-08-01 01:35:09 +00:00
parametrize : bool = False # whether to use the parameterized pathway for LoRAs or not
2024-06-18 02:45:03 +00:00
rvq_levels : list [ int ] = field ( default_factory = lambda : [ ] ) # determines RVQ levels to activate the LoRA
2024-10-26 05:13:10 +00:00
path : Path | None = None
2024-06-17 05:09:16 +00:00
@property
def full_name ( self ) :
name = [ self . name , f " r { self . rank } " , f " a { self . alpha } " ]
return " - " . join ( name )
2024-06-17 18:05:06 +00:00
2024-06-29 03:28:54 +00:00
# actually not needed anymore
2024-06-18 02:45:03 +00:00
def active_level ( self , level ) :
if not self . rvq_levels :
return True
return level in self . rvq_levels
2024-06-04 02:28:49 +00:00
2023-08-02 21:53:35 +00:00
@dataclass ( )
class Hyperparameters :
2024-08-01 01:35:09 +00:00
batch_size : int = 8 # number of samples per training batch
gradient_accumulation_steps : int = 32 # number of steps to accumulate gradients before updating
2024-11-06 19:51:28 +00:00
gradient_clipping : int | float = 1.0 # largest size a gradient norm can be
2023-08-02 21:53:35 +00:00
2024-08-01 01:35:09 +00:00
optimizer : str = " Adamw " # optimizer to use, should be 'Prodigyopt" now
2024-05-10 02:00:26 +00:00
optimizer_params : dict = field ( default_factory = lambda : { } ) # to pass through deepspeed config
2024-06-29 03:28:54 +00:00
learning_rate : float = 3.25e-4 # should be 1.0 for ProdigyOpt
2024-08-01 01:35:09 +00:00
warmup_steps : int = 0 # number of steps to warm up the optimizer before performing updates, I think, this is just passed to deepspeed
2023-08-02 21:53:35 +00:00
2024-08-01 01:35:09 +00:00
scheduler : str = " " # scheduler to use, currently don't ever use one so this doesn't really matter
2024-05-10 01:28:20 +00:00
scheduler_type : str = " " # deprecated
2024-05-10 02:00:26 +00:00
scheduler_params : dict = field ( default_factory = lambda : { } ) # to pass through deepspeed config
2024-05-10 02:25:40 +00:00
2024-08-01 01:35:09 +00:00
autotune : bool = False # to do deepspeed's autotuning
2024-05-10 02:25:40 +00:00
autotune_params : dict = field ( default_factory = lambda : { } ) # to pass through deepspeed config
2024-05-10 02:00:26 +00:00
2024-08-01 01:35:09 +00:00
torch_optimizer : bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
torch_scheduler : bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
2023-08-02 21:53:35 +00:00
@dataclass()
class Evaluation:
    """Evaluation / validation settings and the sampling arguments derived from ``kwargs``."""

    batch_size: int = 64  # number of samples per batch during eval / val
    frequency: int = 250  # do eval / val every X iterations
    size: int = 64  # number of samples to generate during eval / val

    kwargs: dict = field(default_factory=dict)  # inferencing kwargs

    # necessary in order to make it not confusing with requiring not-directly exposed
    # arguments passed to the model
    @cached_property
    def ar_kwargs(self):
        """Return the AR sampling arguments, read from ``kwargs`` with defaults.

        Cached on first access: later edits to ``kwargs`` are not reflected.
        """
        opt = self.kwargs.get
        return dict(
            max_steps=opt("max_ar_steps", 500),
            temperature=opt("ar_temperature", 1.0),
            min_temperature=opt("min_ar_temperature", -1),
            top_p=opt("top_p", 1.0),
            top_k=opt("top_k", 0),
            min_p=opt("min_p", 0.0),
            repetition_penalty=opt("repetition_penalty", 1.0),
            repetition_penalty_decay=opt("repetition_penalty_decay", 0),
            length_penalty=opt("length_penalty", 0),
            beam_width=opt("beam_width", 0),
            mirostat_tau=opt("mirostat_tau", 0),
            mirostat_eta=opt("mirostat_eta", 0),
            dry_multiplier=opt("dry_multiplier", 0),
            dry_base=opt("dry_base", 0),
            dry_allowed_length=opt("dry_allowed_length", 0),
            entropix=opt("entropix_sampling", False),
        )

    @cached_property
    def nar_kwargs(self):
        """Return the NAR sampling arguments, read from ``kwargs`` with defaults.

        Cached on first access: later edits to ``kwargs`` are not reflected.
        Note the minimum-temperature key is ``min_nar_temp``, unlike the AR
        side's ``min_ar_temperature``.
        """
        opt = self.kwargs.get
        return dict(
            max_levels=opt("max_nar_levels", 0),
            temperature=opt("nar_temperature", 0.0),
            min_temperature=opt("min_nar_temp", -1),
            top_p=opt("top_p", 1.0),
            top_k=opt("top_k", 0.0),
            min_p=opt("min_p", 0.0),
            repetition_penalty=opt("repetition_penalty", 1.0),
            repetition_penalty_decay=opt("repetition_penalty_decay", 0.0),
        )
2023-08-04 01:26:36 +00:00
@dataclass()
class DeepSpeed:
    """Settings for the DeepSpeed backend and the DeepSpeed config dict derived from them."""
    zero_optimization_level: int = 0 # doesn't seem to work

    use_compression_training: bool = False # cope
    compression_bits: int = 8 # cope
    inferencing: bool = False # for using DeepSpeed's inferencing wrapper instead

    amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)

    loss_scale_window: int = 100
    min_loss_scale: float = 8192.0

    config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config

    @cached_property
    def ds_cfg(self):
        """Assemble the DeepSpeed configuration dict from the global `cfg` and this object.

        Side effects:
            * fills missing defaults into `cfg.hyperparameters.optimizer_params`,
              `scheduler_params` and `autotune_params` (the dicts are edited in place)
            * clears `self.amp` when the weight dtype is float16
            * clears `cfg.trainer.amp` when DeepSpeed's AMP stays enabled

        Top-level entries whose value is falsy are dropped. If `./data/ds_config.json`
        exists its contents are merged on top, otherwise `self.config` is merged on top.
        """
        optimizer_params = cfg.hyperparameters.optimizer_params
        if 'lr' not in optimizer_params:
            # must be the bare number: a trailing comma here would store a 1-tuple
            optimizer_params["lr"] = cfg.hyperparameters.learning_rate

        scheduler_params = cfg.hyperparameters.scheduler_params
        if 'warmup_num_steps' not in scheduler_params:
            scheduler_params['warmup_num_steps'] = cfg.hyperparameters.warmup_steps
        if 'total_num_steps' not in scheduler_params:
            scheduler_params['total_num_steps'] = cfg.trainer.iterations

        autotune_params = cfg.hyperparameters.autotune_params
        if "enabled" not in autotune_params:
            autotune_params['enabled'] = True
        if "results_dir" not in autotune_params:
            autotune_params['results_dir'] = str(cfg.rel_path / "autotune" / "results")
        if "exps_dir" not in autotune_params:
            autotune_params['exps_dir'] = str(cfg.rel_path / "autotune" / "exps_")

        weight_dtype = cfg.trainer.weight_dtype.lower()
        is_fp16 = weight_dtype == "float16"

        # DeepSpeed fp16 is incompatible with its AMP
        if is_fp16:
            self.amp = False
        # disable local AMP
        if self.amp:
            cfg.trainer.amp = False

        def quantization_shared_parameters():
            # built fresh on every call so the weight and activation sections never share a dict
            return {
                "enabled": True,
                "quantizer_kernel": True,
                "schedule_offset": 0,
                "quantize_groups": 64,
                "quantize_verbose": True,
                "quantization_type": "symmetric",
                "rounding": "nearest",
                "quantize_weight_in_forward": not is_fp16, # MoQ (quantize in optimization step) weight quantization is only supported for FP16
                "fp16_mixed_quantize": {
                    "enabled": False,
                    "quantize_change_ratio": 1,
                },
            }

        ds_cfg = {
            "train_micro_batch_size_per_gpu": cfg.hyperparameters.batch_size,
            "gradient_accumulation_steps": cfg.hyperparameters.gradient_accumulation_steps,
            "optimizer": {
                "type": cfg.hyperparameters.optimizer,
                "params": optimizer_params,
            } if not cfg.hyperparameters.torch_optimizer else None,
            "scheduler": {
                "type": cfg.hyperparameters.scheduler,
                "params": scheduler_params,
            } if not cfg.hyperparameters.torch_scheduler else None,
            "gradient_clipping": cfg.hyperparameters.gradient_clipping,
            "fp16": {
                "enabled": is_fp16,
                "auto_cast": True, # ???
                "loss_scale_window": self.loss_scale_window, # raise every 100 consecutive good steps
                "min_loss_scale": self.min_loss_scale, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
                "loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
            },
            "bf16": {
                "enabled": weight_dtype == "bfloat16",
            },
            "amp": {
                "enabled": self.amp,
            },
            "autotuning": autotune_params if cfg.hyperparameters.autotune else None,
            "compression_training": {
                "weight_quantization": {
                    "shared_parameters": quantization_shared_parameters(),
                    "different_groups": {
                        "wq1": {
                            "params": {
                                "start_bits": self.compression_bits,
                                "target_bits": self.compression_bits,
                                "quantization_period": 0,
                            },
                            "modules": ["self_attn", "mlp"], # for LLaMA, need to find for other arches
                        },
                    },
                },
                "activation_quantization": {
                    "shared_parameters": quantization_shared_parameters(),
                    "different_groups": {
                        "aq1": {
                            "params": {
                                "bits": self.compression_bits,
                            },
                            "modules": ["self_attn", "mlp"], # for LLaMA, need to find for other arches
                        },
                    },
                },
            } if self.use_compression_training else None,
            "zero_optimization": {
                "stage": self.zero_optimization_level,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 5e8,
                "allgather_bucket_size": 5e8,
                "sub_group_size": 5e8,
                "round_robin_gradients": True,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "offload_param": {
                    "device": "cpu",
                    "pin_memory": True,
                },
                "zero_quantized_weights": self.use_compression_training,
                "zero_hpz_partition_size": world_size(),
                "zero_quantized_gradients": self.use_compression_training,
            } if self.zero_optimization_level > 0 else None,
            "comms_logger": {
                "enabled": False,
            },
        }

        # drop every top-level entry that ended up falsy (disabled sections, zeroed values)
        null_keys = [k for k in ds_cfg if not ds_cfg[k]]
        for k in null_keys:
            del ds_cfg[k]

        override_path = "./data/ds_config.json"
        if os.path.exists(override_path):
            # read the file contents first, then parse; the handle is closed by the context manager
            with open(override_path, "r", encoding="utf-8") as f:
                ds_cfg.update(json.load(f))
        else:
            ds_cfg.update(self.config)

        return ds_cfg
2023-08-02 21:53:35 +00:00
2024-05-19 16:23:56 +00:00
@dataclass()
class Trainer:
    """Training-loop settings: checkpointing, weight loading, precision and backend choice."""
    iterations: int = 1_000_000 # upper bound on training iterations

    save_tag: str = "step" # checkpoint name; "step" means the current step count is used
    load_tag: str | None = None # checkpoint tag to load; None falls back to the name stored in `./ckpt/{model-name}/latest`

    save_on_oom: bool = True # write a checkpoint when an OOM error is raised
    save_on_quit: bool = True # write a checkpoint when training is quit

    export_on_save: bool = False # also export weights to a local `fp32.pth` state_dict whenever a checkpoint is saved
    export_on_quit: bool = False # also export weights to a local `fp32.pth` state_dict when training is quit

    save_frequency: int = 100 # save every X iterations

    keep_last_checkpoints: int = 0 # how many checkpoints to keep; the oldest are pruned

    load_state_dict: bool = False # load the `fp32.pth` state_dict; happens automatically when no checkpoint is found but `fp32.pth` exists
    load_states: bool = True #
    strict_loading: bool = False # passes strict_loading=True when loading the state dict
    load_module_only: bool = False #
    restart_step_count: bool = False # wipe the training stats when a checkpoint is loaded
    resize_modules: bool = False # automatically resizes

    activation_checkpointing: bool | None = None # deprecated; technically only for activations rather than whole gradients, but HF only offers gradient checkpointing
    gradient_checkpointing: bool = True # trade a little training speed for lower VRAM use

    check_for_oom: bool = True # watch for OOMs thrown during forward/backwards
    gc_mode: str | None = None # deprecated, but marks when to do GC

    weight_dtype: str = "float16" # dtype the model is kept under

    amp: bool = False # automatic mixed precision
    ddp: bool = False # torch's internal DDP; set automatically when the local backend is used with multiple GPUs
    #scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)

    load_webui: bool = False # load the web UI to allow inferencing during training, to-do: actually make this work

    backend: str = "local" # training backend, currently "local" | "deepspeed"
    deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings

    @cached_property
    def dtype(self):
        """torch dtype named by `weight_dtype`; any other value resolves to `torch.float32`."""
        recognized = ("float16", "bfloat16", "float8_e5m2", "float8_e4m3fn")
        if self.weight_dtype in recognized:
            # resolved lazily so the float8 attributes are only touched when actually requested
            return getattr(torch, self.weight_dtype)
        return torch.float32

    @cached_property
    def scale_loss(self):
        """True exactly when `dtype` is `torch.float16`."""
        # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
        uses_fp16 = self.dtype == torch.float16
        return uses_fp16
2023-08-02 23:36:26 +00:00
@dataclass()
class Inference:
    """Inference-time settings: backend, weight dtype and mixed precision."""
    backend: str = "local" # backend used for inferencing
    weight_dtype: str = "float16" # dtype the model is loaded under
    amp: bool = True # automatic mixed precision while inferencing

    normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though

    @property
    def dtype(self):
        """torch dtype named by `weight_dtype`; any other value resolves to `torch.float32`."""
        recognized = ("float16", "bfloat16", "int8", "float8_e5m2", "float8_e4m3fn")
        if self.weight_dtype in recognized:
            # resolved lazily so the float8 attributes are only touched when actually requested
            return getattr(torch, self.weight_dtype)
        return torch.float32
2023-08-02 23:36:26 +00:00
@dataclass ( )
2024-05-03 01:08:59 +00:00
class Optimizations :
2024-05-10 01:28:20 +00:00
injects : bool = False # overwrites default torch classes (not recommended)
replace : bool = False # replaces modules in place with the optimized version (recommended)
2024-08-04 03:10:21 +00:00
compile : bool | str = False # runs torch.compile on the model
2023-08-02 23:36:26 +00:00
2024-05-10 01:28:20 +00:00
linear : bool = True # inject/replace linear for BnB
embedding : bool = True # inject/replace embedding for BnB
optimizers : bool = True # inject/replace optimizers (BnB, DAdaptation)
2024-03-01 02:29:17 +00:00
2024-05-10 01:28:20 +00:00
bitsandbytes : bool = False # use bitsandbytes
2024-07-16 00:59:48 +00:00
dadaptation : bool = False # use dadaptation optimizer
2024-05-10 01:28:20 +00:00
bitnet : bool = False # use bitnet
fp8 : bool = False # use fp8
2024-04-09 01:14:51 +00:00
2024-10-17 22:06:48 +00:00
# to-do: validate this madness works still, I don't remember what schizodemon told me to do this
2024-08-02 01:12:06 +00:00
model_offloading : dict | None = None # automatically splits the model over a list of devices
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
2024-08-02 01:56:28 +00:00
# | {"assign": [[ f'layers.{i}.' for i in range(0,6) ], [ f'layers.{i}.' for i in range(6,12) ]]} will assign layers 0-5 to device 1, and 6-12 to device 2
2024-08-02 01:12:06 +00:00
2024-08-04 03:10:21 +00:00
tensorrt : bool = False
2024-10-22 23:12:39 +00:00
unsloth : bool = False # unsloth gradient checkpointing (it just offloads tensors to the CPU during backwards, I don't think it's significant enough to bother with on small models)
2024-08-04 03:10:21 +00:00
2023-08-02 21:53:35 +00:00
@dataclass()
class Config(BaseConfig):
    """Top-level configuration object.

    The section fields (`dataset`, `hyperparameters`, ...) default to the section *class*
    itself rather than an instance; `format()` is what turns those placeholders (and any
    plain dicts) into real section objects.
    """
    device: str = "cuda" # target device
    mode: str = "training" # "inferencing"
    experimental: bool = False # debug flag
    silent_errors: bool = False # if False, raise exceptions on errors that could silently lead to problems, if True ignore them

    dataset: Dataset = field(default_factory=lambda: Dataset)
    models: dict | list | None = field(default_factory=lambda: [])
    loras: dict | list | None = field(default_factory=lambda: [])
    hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
    evaluation: Evaluation = field(default_factory=lambda: Evaluation)
    trainer: Trainer = field(default_factory=lambda: Trainer)
    inference: Inference = field(default_factory=lambda: Inference)
    optimizations: Optimizations = field(default_factory=lambda: Optimizations)

    tokenizer: str | None = None # tokenizer class
    tokenizer_path: str = "./tokenizer.json" # tokenizer path

    sample_rate: int = 24_000 # sample rate the model expects
    audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""

    weights_name: str = "fp32"
    weights_format: str = "sft" # "pth" | "sft"
    supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])

    def set_audio_backend(self, audio_backend):
        """Select `audio_backend` and set the sample rate and `model.resp_levels` that go with it.

        Raises:
            Exception: if `audio_backend` is not one of "encodec", "vocos", "dac", "audiodec".
        """
        self.audio_backend = audio_backend

        if audio_backend in ["encodec", "vocos"]:
            self.sample_rate = 24_000
            self.model.resp_levels = 8
        elif audio_backend == "dac":
            self.sample_rate = 44_100
            self.model.resp_levels = 9
        elif audio_backend == "audiodec":
            # must land on the config attribute; a bare local assignment would be discarded
            self.sample_rate = 48_000
            self.model.resp_levels = 8 # ?
        else:
            raise Exception(f"Unknown audio backend: {audio_backend}")

    @property
    def audio_backend_extension(self):
        """File extension associated with the current audio backend, or None if it is unknown."""
        extensions = {
            "encodec": ".enc",
            "vocos": ".enc",
            "dac": ".dac",
            "audiodec": ".dec",
        }
        return extensions.get(self.audio_backend)

    @property
    def model(self):
        """First entry of `models` flagged as `training`, else the first entry, else None."""
        for model in self.models:
            if model.training:
                return model
        return self.models[0] if self.models else None

    # should be renamed to adapters
    @property
    def lora(self):
        """First entry of `loras` flagged as `training`, else the first entry, else None."""
        for lora in self.loras:
            if lora.training:
                return lora
        return self.loras[0] if self.loras else None

    @property
    def distributed(self):
        """True when `world_size()` reports more than one process."""
        return world_size() > 1

    @cached_property
    def get_spkr(self):
        """Result of evaluating the expression stored in `dataset.speaker_name_getter`."""
        # NOTE(review): eval() executes arbitrary code taken from the config; only load trusted configs
        return eval(self.dataset.speaker_name_getter)

    @cached_property
    def get_spkr_group(self):
        """Result of evaluating the expression stored in `dataset.speaker_group_getter`."""
        # NOTE(review): eval() executes arbitrary code taken from the config; only load trusted configs
        return eval(self.dataset.speaker_group_getter)

    @cached_property
    def diskcache(self):
        """`diskcache` memoize decorator factory when a YAML is loaded and caching is on, else a no-op factory."""
        if self.yaml_path is not None and self.dataset.cache:
            return diskcache.Cache(self.cache_dir).memoize
        # same call shape as memoize: calling it yields a decorator that returns the function untouched
        return lambda: lambda x: x

    # this gets called from vall_e.inference
    def load_yaml(self, config_path):
        """Overwrite this object's attributes with those of a config built by `Config.from_yaml`."""
        tmp = Config.from_yaml(config_path)
        self.__dict__.update(tmp.__dict__)

    def load_model(self, config_path, lora_path=None):
        """Overwrite this object's attributes with those of a config built by `Config.from_model`."""
        tmp = Config.from_model(config_path, lora_path)
        self.__dict__.update(tmp.__dict__)

    def load_hdf5(self, write=False):
        """(Re)open the dataset's HDF5 file under `rel_path`.

        Opens in append mode when `write` is set, otherwise with `dataset.hdf5_flag`
        (forced to "r" when distributed). On failure a warning is logged and
        `dataset.use_hdf5` is switched off instead of raising.
        """
        if hasattr(self, 'hdf5'):
            self.hdf5.close()

        if self.distributed:
            self.dataset.hdf5_flag = "r"
        try:
            self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
        except Exception as e:
            _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
            self.dataset.use_hdf5 = False

    # a very icky way to handle wildcard expansions
    def expand(self, path):
        """Expand a path containing `*` into the list of matching entries.

        Paths without a wildcard come back as a one-element list. Otherwise the pattern is
        globbed under `metadata_dir` (file stems are used) or, failing that, under `data_dir`
        (file names are used); results are relative to neither, keeping `path.parent` as prefix.

        Raises:
            Exception: when neither parent directory exists and `silent_errors` is off
                (with `silent_errors` on, an empty list is returned instead).
        """
        if not isinstance(path, Path):
            path = Path(path)

        # do not glob
        if "*" not in str(path):
            return [path]

        # resolved against this config rather than the module-level instance
        metadata_parent = self.metadata_dir / path.parent
        data_parent = self.data_dir / path.parent

        if metadata_parent.exists():
            return [path.parent / child.stem for child in metadata_parent.glob(path.name)]

        if data_parent.exists():
            return [path.parent / child.name for child in data_parent.glob(path.name)]

        # return an empty list
        if self.silent_errors:
            return []

        # raise an error to avoid headaches
        raise Exception(f'Cannot unglob requested path: {path}')

    def format(self, training=True):
        """Coerce raw/placeholder values into real section objects and finish initialization.

        Steps: replace class placeholders with dicts, build the section dataclasses, expand
        dataset wildcards, strip deprecated model flags, build models/LoRAs, reconcile
        scheduler / DDP / checkpointing settings, optionally open the HDF5 file, and load
        the tokenizer.

        Args:
            training: when False, HDF5 usage is turned off before the file would be opened.

        Raises:
            Exception: for an un-expandable dataset glob or a missing tokenizer file, unless
                `silent_errors` is set.
        """
        # pass 1: a field still holding its placeholder class becomes an empty dict
        for name in ("dataset", "models", "loras", "hyperparameters", "evaluation", "trainer", "inference", "optimizations"):
            if isinstance(getattr(self, name), type):
                setattr(self, name, dict())

        # pass 2: dicts become their section dataclass
        section_types = (
            ("dataset", Dataset),
            ("hyperparameters", Hyperparameters),
            ("evaluation", Evaluation),
            ("trainer", Trainer),
            ("inference", Inference),
            ("optimizations", Optimizations),
        )
        for name, section_type in section_types:
            section = getattr(self, name)
            if isinstance(section, dict):
                setattr(self, name, section_type(**section))

        # trainer.deepspeed may be a dict (from YAML) or still the placeholder class (section omitted);
        # either way it has to end up as an instance, otherwise `ds_cfg` / `config` are unusable
        deepspeed = self.trainer.deepspeed
        if isinstance(deepspeed, type):
            deepspeed = dict()
        if isinstance(deepspeed, dict):
            self.trainer.deepspeed = DeepSpeed(**deepspeed)

        # convert to expanded paths, then flatten
        for split in ("training", "validation", "noise"):
            expanded = [self.expand(entry) for entry in getattr(self.dataset, split)]
            setattr(self.dataset, split, list(itertools.chain.from_iterable(expanded)))

        # do cleanup
        for model in self.models:
            if not isinstance(model, dict):
                continue

            # to-do: prune unused keys in here too automatically
            if "experimental" not in model or not model["experimental"]:
                model["experimental"] = {}

            for flag in ("prom_levels", "interleave"):
                if flag in model:
                    _logger.warning(f"Deprecated flag found: cfg.model.{flag}")
                    del model[flag]

            experimental = model["experimental"]
            # renamed setting: carry the old value over under the new key
            if "p_rvq_levels" in experimental:
                experimental["rvq_levels_p"] = experimental["p_rvq_levels"]
                del experimental["p_rvq_levels"]

            # removed settings: silently discard
            for key in ("p_len_train", "masking_ratio_fixed"):
                if key in experimental:
                    del experimental[key]

        self.models = [Model(**model) if isinstance(model, dict) else model for model in self.models]
        self.loras = [LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras]

        if not self.models:
            self.models = [Model()]

        for model in self.models:
            if isinstance(model.experimental, dict):
                model.experimental = ModelExperimentalSettings(**model.experimental)

        if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
            self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
            self.hyperparameters.scheduler_type = ""

        # do not combine the two
        if self.hyperparameters.scheduler == "schedulefree" and self.optimizations.dadaptation:
            self.hyperparameters.scheduler = ""

        if self.hyperparameters.scheduler == "":
            self.hyperparameters.torch_scheduler = True

        if self.trainer.backend == "local" and self.distributed:
            self.trainer.ddp = True

        # the deprecated flag, when given, overrides gradient_checkpointing
        if self.trainer.activation_checkpointing is not None:
            self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing

        if not training:
            self.dataset.use_hdf5 = False

        # load our HDF5 file if requested here
        if self.dataset.use_hdf5:
            self.load_hdf5()

        # load tokenizer
        if self.tokenizer == "naive":
            self.tokenizer = NaiveTokenizer()
        else:
            from transformers import PreTrainedTokenizerFast

            tokenizer_path = self.rel_path / self.tokenizer_path
            # deduce path if a local copy is not provided
            if not tokenizer_path.exists():
                tokenizer_path = Path("./data/") / self.tokenizer_path

            if not self.silent_errors and not tokenizer_path.exists():
                raise Exception(f'Tokenizer path not found: {tokenizer_path}')

            self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
2024-04-30 03:14:01 +00:00
# Preserves the old behavior
# Tokenizer backed by the hard-coded symbol -> id table returned by get_vocab().
# NOTE(review): whitespace inside the string literals of this block looks altered by extraction
# (empty and single-space literals are indistinguishable here) -- confirm the literal contents
# against the upstream file before changing any of this logic.
class NaiveTokenizer :
# Returns the symbol -> id table; the dict literal is rebuilt on every call.
# The leading triple-quoted string is an inert expression holding a disabled HDF5 'symmap' lookup.
# Note that the keys for '?' and 'p' both carry id 7.
def get_vocab ( self ) :
"""
if cfg . dataset . use_hdf5 and ' symmap ' in cfg . hdf5 :
return json . loads ( cfg . hdf5 [ ' symmap ' ] . asstr ( ) [ ( ) ] )
"""
return { ' <s> ' : 1 , ' </s> ' : 2 , ' ' : 3 , ' . ' : 4 , ' , ' : 5 , ' ! ' : 6 , ' ? ' : 7 , ' p ' : 7 , ' iː ' : 8 , ' ɚ ' : 9 , ' ˌ ' : 10 , ' dˌ ' : 11 , ' mˌ ' : 12 , ' d ' : 13 , ' ɹ ' : 14 , ' tˈ ' : 15 , ' pˌ ' : 16 , ' uː ' : 17 , ' l ' : 18 , ' æ ' : 19 , ' ɛ ' : 20 , ' ɪ ' : 21 , ' j ' : 22 , ' ʊ ' : 23 , ' t ' : 24 , ' n ' : 25 , ' v ' : 26 , ' a ' : 27 , ' o ' : 28 , ' ŋ ' : 29 , ' w ' : 30 , ' ʌ ' : 31 , ' hˈ ' : 32 , ' ɡ ˈ ' : 33 , ' ə ' : 34 , ' θˈ ' : 35 , ' dˈ ' : 36 , ' wˌ ' : 37 , ' h ' : 38 , ' z ' : 39 , ' k ' : 40 , ' ð ' : 41 , ' ɡˌ ' : 42 , ' ˈ ' : 43 , ' fˈ ' : 44 , ' i ' : 45 , ' s ' : 46 , ' ʃ ' : 47 , ' wˈ ' : 48 , ' ðˈ ' : 49 , ' ɹˈ ' : 50 , ' lˈ ' : 51 , ' ɡ ' : 52 , ' oː ' : 53 , ' mˈ ' : 54 , ' e ' : 55 , ' ɑ ː ' : 56 , ' nˈ ' : 57 , ' m ' : 58 , ' θˌ ' : 59 , ' sˈ ' : 60 , ' f ' : 61 , ' ɔː ' : 62 , ' hˌ ' : 63 , ' b ' : 64 , ' jˈ ' : 65 , ' ɐ ' : 66 , ' ʒˈ ' : 67 , ' θ ' : 68 , ' bˈ ' : 69 , ' ɾ ' : 70 , ' ɜː ' : 71 , ' ʌˈ ' : 72 , ' ʃˌ ' : 73 , ' bˌ ' : 74 , ' kˈ ' : 75 , ' ɔ ' : 76 , ' zˈ ' : 77 , ' ᵻ ' : 78 , ' kˌ ' : 79 , ' vˈ ' : 80 , ' fˌ ' : 81 , ' ʒ ' : 82 , ' ʃˈ ' : 83 , ' ɹˌ ' : 84 , ' tˌ ' : 85 , ' pˈ ' : 86 , ' ðˌ ' : 87 , ' sˌ ' : 88 , ' nˌ ' : 89 , ' lˌ ' : 90 , ' ̩ ' : 91 , ' ʔ ' : 92 , ' vˌ ' : 93 , ' ɪ ˈ ' : 94 , ' " ' : 95 , ' ɪˌ ' : 96 , ' ʒˌ ' : 97 , ' uː ˌ ' : 98 , ' ʊˈ ' : 99 , ' jˌ ' : 100 , ' uː ˈ ' : 101 , ' iː ˈ ' : 102 , ' zˌ ' : 103 , ' .ˈ ' : 104 , ' … ' : 105 , ' ŋˌ ' : 106 , ' ɐˌ ' : 107 , ' —ˈ ' : 108 , ' iˌ ' : 109 , ' iː ˌ ' : 110 , ' ɛː ' : 111 , ' ) ' : 112 , ' )ˈ ' : 113 , ' ( ' : 114 , ' u ' : 115 , ' - ' : 116 , ' ɖˈ ' : 117 , ' iˈ ' : 118 , ' ʰˈ ' : 119 , ' ɟˈ ' : 120 , ' ̃ ' : 121 , ' eː ' : 122 , ' ɾˈ ' : 123 , ' r ' : 124 , ' ʰ ' : 125 , ' -ˌ ' : 126 , ' ɫ ' : 127 , ' q ' : 128 , ' — ' : 129 , ' ʊˌ ' : 130 , ' aː ' : 131 , ' cˈ ' : 132 , ' …ˈ ' : 133 , ' c ' : 134 , ' ɳ ' : 135 , ' ɐˈ ' : 136 , ' x ' : 137 , ' ʔˌ ' : 138 , ' .ˌ ' : 139 , ' ɑ ' : 140 , ' ?ˈ ' : 141 , ' ̩ˈ ' : 142 , ' " ˈ 
' : 143 , ' ,ˈ ' : 144 , ' ŋˈ ' : 145 , ' əˌ ' : 146 , ' !ˈ ' : 147 , ' " ˌ ' : 148 , ' ?ˌ ' : 149 , ' ,ˌ ' : 150 , ' —ˌ ' : 151 , ' ̩ˌ ' : 152 , ' əˈ ' : 153 , ' !ˌ ' : 154 , ' ɬ ' : 155 , ' ʲ ' : 156 , ' ¡ ' : 157 , ' ɯ ' : 158 , ' qˌ ' : 159 , ' ʑ ' : 160 , ' ʑˈ ' : 161 , ' ¿ ' : 162 , ' ɑ ː ˈ ' : 163 , ' iː ː ' : 164 , ' ɛˈ ' : 165 , ' ¡ˈ ' : 166 , ' æˈ ' : 167 , ' ç ' : 168 , ' ɾˌ ' : 169 , ' ᵻˈ ' : 170 , ' xˈ ' : 171 , ' ɔːˈ ' : 172 , ' ; ' : 173 , ' ɬˌ ' : 174 , ' : ' : 175 , ' ʔ ˈ ' : 176 , ' ɑːˌ ' : 177 , ' ɬˈ ' : 178 , ' ” ' : 179 , ' “ ' : 180 , ' “ˈ ' : 181 , ' “ˌ ' : 182 , ' ;ˈ ' : 183 , ' ;ˌ ' : 184 , ' :ˈ ' : 185 , ' 1 ' : 186 , ' rˈ ' : 187 , ' qˈ ' : 188 , ' ᵻˌ ' : 189 , ' ä ' : 190 , ' ̞ˌ ' : 191 , ' ̞ ' : 192 , ' ũˌ ' : 193 , ' ʑˌ ' : 194 , ' ᵝ ' : 195 , ' ɽ ' : 196 , ' ʲˌ ' : 197 , ' ᵝˌ ' : 198 , ' ũ ' : 199 , ' ũˈ ' : 200 , ' äˌ ' : 201 , ' ɕ ' : 202 , ' ɕˌ ' : 203 , ' ɽˌ ' : 204 , ' çˌ ' : 205 , ' …ˌ ' : 206 , ' ̞ˈ ' : 207 , ' äˈ ' : 208 , ' ɽˈ ' : 209 , ' ɸˌ ' : 210 , ' ɴ ' : 211 , ' ɸˈ ' : 212 , ' ɕˈ ' : 213 , ' ɸ ' : 214 , ' ᵝˈ ' : 215 , ' ʲˈ ' : 216 , ' ĩ ' : 217 , ' çˈ ' : 218 , ' ĩˌ ' : 219 , ' oˌ ' : 220 , ' eˈ ' : 221 , ' ʍ ' : 222 , ' eˌ ' : 223 , ' uˌ ' : 224 , ' ʍˌ ' : 225 , ' uˈ ' : 226 , ' oˈ ' : 227 , ' aˈ ' : 228 }
2024-07-18 21:16:14 +00:00
# Id of the "<s>" entry in the vocab; computed once per instance (cached_property).
@cached_property
def _bos_token ( self ) :
return self . get_vocab ( ) [ " <s> " ]
# Id of the "</s>" entry in the vocab; computed once per instance (cached_property).
@cached_property
def _eos_token ( self ) :
return self . get_vocab ( ) [ " </s> " ]
2024-04-30 03:14:01 +00:00
# Turns the string s into a list of ids.
# The characters of s are joined with a separator, then each of the marks U+02C8, U+02CC and
# U+02D0 has the separator in front of it removed (presumably re-attaching the mark to the
# preceding symbol -- confirm against upstream), and the result is split back into symbols.
# "<s>" / "</s>" are added around the symbols and every symbol is looked up with dict.get,
# so a symbol missing from the vocab yields None in the returned list instead of raising.
def encode ( self , s ) :
symmap = self . get_vocab ( )
phones = " " . join ( list ( s ) )
# do merge
for merge in [ " \u02C8 " , " \u02CC " , " \u02D0 " ] :
phones = phones . replace ( f ' { merge } ' , merge )
phones = phones . split ( " " )
# cleanup
phones = [ p for i , p in enumerate ( phones ) if p not in [ " " ] or ( p in [ " " ] and p != phones [ i - 1 ] ) ]
# add bos / eos
phones = [ " <s> " ] + [ " " if not p else p for p in phones ] + [ " </s> " ]
# tokenize
return [ * map ( symmap . get , phones ) ]
2024-09-06 01:43:20 +00:00
# Turns an iterable of ids t back into a string by concatenating the symbol for each id.
# The id -> symbol table is rebuilt from get_vocab() on every call; when two symbols share
# an id, the one appearing later in the vocab wins. An id absent from the table raises KeyError.
def decode ( self , t ) :
s = " "
symmap = self . get_vocab ( )
reverse_symmap = { }
for k , v in symmap . items ( ) :
reverse_symmap [ v ] = k
for i , token in enumerate ( t ) :
s + = reverse_symmap [ token ]
return s
2024-08-29 18:27:16 +00:00
_logger = logging.getLogger(__name__)

# the shared config instance, built via Config.from_cli() when this module is imported
cfg = Config.from_cli()

# some safety for remapping deprecated formats and re-coercing uninitialized properties into actual types
try:
    cfg.format()
except Exception as e:
    if cfg.silent_errors:
        # best-effort mode: report the problem and carry on with whatever got formatted
        _logger.error(f"Error while parsing config YAML: {str(e)}")
    else:
        raise e # throw an error because I'm tired of silent errors messing things up for me

if __name__ == "__main__":
    print(cfg)