cleaned up unused config flags, allowed less strict YAML loading by pruning keys missing from the config, renamed some dataset config fields to be more unified
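For reference, the field renames visible in this diff, sketched as a hand-assembled mapping (old name on the left, new name on the right; these dicts are illustrative only and not part of the codebase). YAMLs still using the old names will have those keys pruned on load unless a migration in format() covers them:

# dataset fields
renamed_dataset_keys = {
    "max_prompts":          "prompt_max_samples",
    "max_resps":            "resps_max_samples",
    "p_resp_append":        "resps_append_p",
    "p_resp_pad_silence":   "resps_pad_silence_p",
    "inject_noise_in_prom": "noise_inject_in_prom",
}
# model.experimental fields
renamed_model_keys = {
    "p_rvq_levels": "rvq_levels_p",  # format() still migrates this one with a deprecation warning
}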

This commit is contained in:
mrq 2024-10-17 17:06:48 -05:00
parent 8b6095f681
commit 75b90be325
8 changed files with 113 additions and 109 deletions

View File

@ -21,15 +21,7 @@ from functools import cached_property
from pathlib import Path
from .utils.distributed import world_size
def set_seed(seed=None):
if not seed:
seed = time.time()
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
from .utils import set_seed, prune_missing
@dataclass()
class BaseConfig:
@ -37,7 +29,7 @@ class BaseConfig:
@property
def cfg_path(self):
return Path(self.yaml_path.parent) if self.yaml_path is not None else None
return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data"
@property
def rel_path(self):
@ -95,11 +87,24 @@ class BaseConfig:
with open(path, "w") as f:
f.write(self.dumps())
# ick
@classmethod
def prune_missing( cls, yaml ):
default = cls(**{})
default.format()
#default = json.loads(default.dumps())
yaml, missing = prune_missing( source=default, dest=yaml )
if missing:
_logger.warning(f'Missing keys in YAML: {missing}')
return yaml
@classmethod
def from_yaml( cls, yaml_path ):
state = {}
state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
state.setdefault("yaml_path", yaml_path)
state = cls.prune_missing( state )
return cls(**state)
@classmethod
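To illustrate the less strict loading path added above: from_yaml() parses the file, prune_missing() builds a default config and drops any keys the dataclasses no longer define, logging whatever was removed. A minimal usage sketch, assuming the package is importable as vall_e and using an illustrative path and stale key:

from vall_e.config import Config  # assumed import path, not shown in this diff

cfg = Config.from_yaml( "./training/config.yaml" )  # illustrative path
# a stale key such as dataset.max_prompts (renamed to prompt_max_samples in this commit)
# should now be pruned instead of raising a TypeError when the nested dataclasses are
# constructed, with a warning along the lines of:
#   Missing keys in YAML: ['dataset.max_prompts']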
@ -130,52 +135,48 @@ class Dataset:
validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset
temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out
# to-do: replace these since I feel this can be a bottleneck
speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utterance path
speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utterance path
# to-do: validate if I can ignore this since this is an artifact from when I only saved phonemes and encoded audio, and no metadata
speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
use_hdf5: bool = False # whether to load from an HDF5 dataset
hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
use_metadata: bool = False # use generated metadata to aid in dataset loading
validate: bool = True # validate each utterance on whether it can be included based on duration range caps
workers: int = 8 # number of dataloader workers to spawn
cache: bool = True # use diskcache to cache the dataset
phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
# to-do: clean up the following block, it's a mess
min_utterances: int = 2 # minimum number of utterances a speaker can have
random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
prompt_duration: float | None = None # legacy
max_resps: int = 1 # number of samples to target for training
p_resp_append: float = 1.0 # probability to append another sample to the training target
p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
sample_type: str = "path" # path | speaker
sample_order: str = "interleaved" # duration
sample_shuffle: bool = True # shuffles the indices in the sampler
sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batch, 0 to disable
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
sample_shuffle: bool = True # I swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0
prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated)
prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
prompt_max_samples: int = 3 # maximum number of utterances that can be included in an input prompt for training
prompt_continuous_utterance_p: float = 0.0 # probability to use the target utterance as an input prompt rather than using a different utterance
prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
resps_max_samples: int = 1 # number of samples to target for training
resps_append_p: float = 1.0 # probability to append another sample to the training target
resps_pad_silence_p: float = 0.0 # probability to pad resp with silence to fit within the next window
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
noise_scale: float = 0.25 # scaling noise value
inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
retokenize_text: bool = False
_frames_per_second: int = 0 # allows setting your own hint
@ -219,7 +220,7 @@ class ModelExperimentalSettings:
# in theory a model that is trained to sum embeddings can perform better due to "seeing" previous levels (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so
audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
kv_heads: int = 0 # MHA or GQA (for supported backends)
p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_levels_p: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necessary
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
@ -431,8 +432,6 @@ class Evaluation:
nar_temperature: float = 0.0 # NAR temp for inferencing
nar_levels: int = 0 # maximum NAR levels to use for inferencing
load_disabled_engines: bool = True # see the other load_disabled_engines
@dataclass()
class DeepSpeed:
zero_optimization_level: int = 0 # doesn't seem to work
@ -617,10 +616,8 @@ class Trainer:
activation_checkpointing: bool | None = None # deprecated, should technically apply only to activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
aggressive_optimizations: bool = False # deprecated
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training, for example for evaluation/validation
weight_dtype: str = "float16" # dtype to have the model under
@ -628,8 +625,7 @@ class Trainer:
ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when the logger broke because of BitNet
load_webui: bool = False # load the web UI to allow inferencing during training, to-do: actually make this work
backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@ -657,13 +653,7 @@ class Inference:
weight_dtype: str = "float32" # dtype to load the model under
amp: bool = False # automatic mixed precision during inferencing
normalize: bool = False # do NOT enable this unless you know exactly what you're doing
# legacy / backwards compat
audio_backend: str = "" # encodec, vocos, dac
use_vocos: bool = True
use_encodec: bool = True
use_dac: bool = True
normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though
@property
def dtype(self):
@ -694,6 +684,7 @@ class Optimizations:
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
# to-do: validate this madness works still, I don't remember what schizodemon told me to do this
model_offloading: dict | None = None # automatically splits the model over a list of devices
# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
@ -705,7 +696,7 @@ class Optimizations:
class Config(BaseConfig):
device: str = "cuda" # target device
mode: str = "training" # "inferencing"
experimental: bool = False # Debug flag, unused now
experimental: bool = False # debug flag
dataset: Dataset = field(default_factory=lambda: Dataset)
models: dict | list | None = field(default_factory=lambda: [])
@ -714,7 +705,6 @@ class Config(BaseConfig):
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: dict | list | None = None # deprecated
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
tokenizer: str | None = None # tokenizer class
@ -828,7 +818,6 @@ class Config(BaseConfig):
return path
# to-do: prune unused keys
def format( self, training=True ):
if isinstance(self.dataset, type):
self.dataset = dict()
@ -869,19 +858,25 @@ class Config(BaseConfig):
if not isinstance( model, dict ):
continue
if "prom_levels" in model:
del model["prom_levels"]
if "interleave" in model:
del model["interleave"]
if "audio_embedding_sums" not in model:
continue
if "experimental" not in model or not model["experimental"]:
model["experimental"] = {}
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
if "prom_levels" in model:
_logger.warning(f"Deprecated flag found: {'cfg.model.prom_levels'}")
del model["prom_levels"]
if "interleave" in model:
_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
del model["interleave"]
if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
del model["experimental"]["p_rvq_levels"]
if "audio_embedding_sums" in model:
_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
self.models = [ Model(**model) for model in self.models ]
@ -905,11 +900,7 @@ class Config(BaseConfig):
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
self.inference = Inference(**self.inference)
if self.bitsandbytes is not None:
self.optimizations = Optimizations(**self.bitsandbytes)
else:
self.optimizations = Optimizations(**self.optimizations)
self.optimizations = Optimizations(**self.optimizations)
if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
@ -922,15 +913,9 @@ class Config(BaseConfig):
if self.hyperparameters.scheduler == "":
self.hyperparameters.torch_scheduler = True
if self.dataset.prompt_duration is not None:
self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
if self.trainer.backend == "local" and self.distributed:
self.trainer.ddp = True
if self.inference.audio_backend != "" and self.audio_backend == "":
self.audio_backend = self.inference.audio_backend
if self.trainer.activation_checkpointing is not None:
self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
@ -942,22 +927,23 @@ class Config(BaseConfig):
self.load_hdf5()
# load tokenizer
if cfg.tokenizer == "naive":
cfg.tokenizer = NaiveTokenizer()
if self.tokenizer == "naive":
self.tokenizer = NaiveTokenizer()
else:
# ick...
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None
if tokenizer_path and not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
tokenizer_path = Path("./data/") / self.tokenizer_path
if tokenizer_path and tokenizer_path.exists():
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
cfg.tokenizer = NaiveTokenizer()
self.tokenizer = NaiveTokenizer()
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
self.tokenizer = NaiveTokenizer()
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass

View File

@ -1012,10 +1012,12 @@ class Dataset(_Dataset):
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
for _ in range(cfg.dataset.max_prompts):
if reference is not None and cfg.dataset.prom_sample_similar:
path = self.get_similar_utterance( reference, offset = len(prom_list) ) if random.random() < cfg.dataset.prompt_similar_p else random.choice(choices)
for _ in range(cfg.dataset.prompt_max_samples):
if reference is not None:
# yuck
path = None
if random.random() < cfg.dataset.prompt_similar_p:
path = self.get_similar_utterance( reference, offset = len(prom_list) )
if not path:
path = random.choice(choices)
else:
@ -1032,7 +1034,7 @@ class Dataset(_Dataset):
prom_list.append(qnt)
prom_length += qnt.shape[0]
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
if prom_length >= trim_length:
break
# might be better to decode => concat waveforms with silence in between => reencode
@ -1113,9 +1115,9 @@ class Dataset(_Dataset):
naive = cfg.experimental
# append additional prompts in an attempt to artificially increase lengths / offer new data
if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
ignore_paths = []
for _ in range( 1, cfg.dataset.max_resps ):
for _ in range( 1, cfg.dataset.resps_max_samples ):
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
ignore_paths.append(path)
@ -1316,7 +1318,7 @@ class Dataset(_Dataset):
text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
# pad the target with silence
if random.random() < cfg.dataset.p_resp_pad_silence:
if random.random() < cfg.dataset.resps_pad_silence_p:
resps = pad_codes_with_silence( resps )
return dict(

View File

@ -525,9 +525,6 @@ class Engines(dict[str, Engine]):
n_ooms = torch.zeros([], device=device)
if cfg.trainer.aggressive_optimizations:
batch = to_device(batch, 'cpu')
if not cfg.trainer.check_for_oom:
engine.backward(loss)
else:

View File

@ -79,7 +79,7 @@ class AR(Base):
# is training
if training:
# specifies how to sample probabilities of which RVQ levels to train against
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
@ -90,19 +90,19 @@ class AR(Base):
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not p_rvq_levels:
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if p_rvq_levels == "equal":
p_rvq_levels = [ i for i in range( lo, hi ) ]
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels
if not self.interleave:
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
# trim resps to only contain all levels below the target level
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
else:
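To make the two pools above concrete, a small worked example with illustrative bounds (lo = 0, hi = 4, i.e. four RVQ levels):

lo, hi = 0, 4  # illustrative; the real values come from quant_level_range

# rvq_levels_p == "equal": every level is equally likely
equal_pool = [ i for i in range( lo, hi ) ]
# -> [0, 1, 2, 3]

# anything else (including the default "auto"): lower levels are repeated more often,
# so random.choice() is biased toward level 0 (the AR level)
weighted_pool = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# -> [0, 0, 0, 0, 1, 1, 1, 2, 2, 3]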

View File

@ -102,7 +102,7 @@ class AR_NAR(Base):
# is training
if training:
# specifies how to sample probabilities of which RVQ levels to train against
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
@ -113,18 +113,18 @@ class AR_NAR(Base):
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not p_rvq_levels:
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if p_rvq_levels == "equal":
p_rvq_levels = [ i for i in range( lo, hi ) ]
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
for i, task in enumerate( task_list ):
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1

View File

@ -75,7 +75,7 @@ class NAR(Base):
task_list = [ sample_task() for _ in range(batch_size) ]
# specifies how to sample probabilities of which RVQ levels to train against
p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
@ -86,18 +86,18 @@ class NAR(Base):
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
if not p_rvq_levels:
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if p_rvq_levels == "equal":
p_rvq_levels = [ i for i in range( lo, hi ) ]
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels
quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
# trim resps to only contain all levels below the target level
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]

View File

@ -11,5 +11,6 @@ from .utils import (
passes_policy,
get_devices,
truncate_json,
timer
timer,
prune_missing
)

View File

@ -31,6 +31,24 @@ from datetime import datetime
T = TypeVar("T")
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
is_obj = hasattr( source, "__dict__" )
if parent_is_obj is None:
parent_is_obj = is_obj
haystack = source.__dict__ if is_obj else source
keep = {}
missing = []
for k, v in dest.items():
if k in haystack or (parent_is_obj and not is_obj and source == {}):
keep[k] = dest[k]
else:
missing.append(".".join(path + [k]))
# only recurse into keys that actually exist in the source, otherwise haystack[k] raises a KeyError
if recurse and isinstance( v, dict ) and k in haystack:
keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing )
missing += m
return (keep, missing) if return_missing else keep
class timer:
def __init__(self, msg="Elapsed time:", callback=None):
self.msg = msg
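A toy usage sketch of prune_missing() with plain dicts (keys and values below are illustrative): entries present in dest but absent from source are dropped and reported with their dotted path:

defaults = { "dataset": { "prompt_max_samples": 3 }, "device": "cuda" }
loaded   = { "dataset": { "prompt_max_samples": 1, "max_prompts": 3 }, "device": "cuda" }

kept, missing = prune_missing( source=defaults, dest=loaded )
# kept    == { "dataset": { "prompt_max_samples": 1 }, "device": "cuda" }
# missing == [ "dataset.max_prompts" ]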