cleaned up unused config flags, allow less strict yaml by pruning missing keys, renamed some dataset configs to be more unified

parent 8b6095f681
commit 75b90be325

vall_e/config.py (140 changed lines)
@@ -21,15 +21,7 @@ from functools import cached_property
 from pathlib import Path
 
 from .utils.distributed import world_size
+from .utils import set_seed, prune_missing
-
-def set_seed(seed=None):
-	if not seed:
-		seed = time.time()
-
-	random.seed(seed)
-	np.random.seed(seed)
-	torch.manual_seed(seed)
 
 @dataclass()
 class BaseConfig:
@@ -37,7 +29,7 @@ class BaseConfig:
 
 	@property
 	def cfg_path(self):
-		return Path(self.yaml_path.parent) if self.yaml_path is not None else None
+		return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data"
 
 	@property
 	def rel_path(self):
@@ -95,11 +87,24 @@ class BaseConfig:
 		with open(path, "w") as f:
 			f.write(self.dumps())
 
+	# ick
+	@classmethod
+	def prune_missing( cls, yaml ):
+		default = cls(**{})
+		default.format()
+		#default = json.loads(default.dumps())
+
+		yaml, missing = prune_missing( source=default, dest=yaml )
+		if missing:
+			_logger.warning(f'Missing keys in YAML: {missing}')
+		return yaml
+
 	@classmethod
 	def from_yaml( cls, yaml_path ):
 		state = {}
 		state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8"))
 		state.setdefault("yaml_path", yaml_path)
+		state = cls.prune_missing( state )
 		return cls(**state)
 
 	@classmethod
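Note: the looser loading path means a stale YAML no longer has to be hand-pruned before it will parse. A minimal sketch of the intended behaviour, assuming the usual entry point of loading the top-level Config (which inherits from_yaml from BaseConfig); the path below is hypothetical:

	# removed keys such as dataset.max_prompts are now dropped with a warning
	# instead of raising an unexpected-keyword error from the dataclass constructor
	cfg = Config.from_yaml( Path("./training/old_config.yaml") )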
@@ -130,52 +135,48 @@ class Dataset:
 	validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset
 	noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset
 
-	temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out
+	# to-do: replace these since I feel this can be a bottleneck
 
 	speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utternace path
 	speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utternace path
+	# to-do: validate if I can ignore this since this is an artifact from when I only saved phonemes and encoded audio, and no metadata
 	speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups
 
-	hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
 	use_hdf5: bool = False # whether to load from an HDF5 dataset
+	hdf5_name: str = "data.h5" # file name to load the HDF5 dataset
 	hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways
 
 	use_metadata: bool = False # use genretaed metadata to aid in dataset loading
 
 	validate: bool = True # validate each utterance on wheter it can be included based on duration range caps
 	workers: int = 8 # number of dataloader workers to spawn
 	cache: bool = True # use diskcache to cache the dataset
 
-	phones_range: list[int] = field(default_factory=lambda: [4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset
-	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
-	prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
-
-	# to-do: clean up the following block, it's a mess
 	min_utterances: int = 2 # minimum number of utterances a speaker can have
-	random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt
-	max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training
-	prompt_duration: float | None = None # legacy
-	max_resps: int = 1 # number of samples to target for training
-	p_resp_append: float = 1.0 # probability to append another sample to the training target
-	p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window
-	prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
-	prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
-	prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
+	duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset
 
 	sample_type: str = "path" # path | speaker
 	sample_order: str = "interleaved" # duration
+	sample_shuffle: bool = True # shuffles the indices in the sampler
 	sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
 	# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
 	# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
-	sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0
 
-	prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated)
+	prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be
+	prompt_max_samples: int = 3 # maximum number of utterances that can be included in an input prompt for training
+	prompt_continuous_utterance_p: float = 0.0 # probability to use the target utterance as an input prompt rather than using a different utterance
+	prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
+	prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
+	prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
+
+	resps_max_samples: int = 1 # number of samples to target for training
+	resps_append_p: float = 1.0 # probability to append another sample to the training target
+	resps_pad_silence_p: float = 0.0 # probability to pad resp with silence to fit within the next window
 
 	tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
 	reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
 	reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
 	noise_scale: float = 0.25 # scaling noise value
-	inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
+	noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
 	retokenize_text: bool = False
 
 	_frames_per_second: int = 0 # allows setting your own hint
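Note: for anyone updating an existing YAML by hand, the dataset-level renames visible in this hunk map roughly as below. This is a convenience summary assembled from the diff, not an exhaustive migration table:

	# old dataset key -> new dataset key
	DATASET_KEY_RENAMES = {
		"max_prompts": "prompt_max_samples",
		"max_resps": "resps_max_samples",
		"p_resp_append": "resps_append_p",
		"p_resp_pad_silence": "resps_pad_silence_p",
		"inject_noise_in_prom": "noise_inject_in_prom",
	}
	# removed outright: temp, phones_range, prompt_duration (legacy), random_utterance, prom_sample_similar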
@@ -219,7 +220,7 @@ class ModelExperimentalSettings:
 	# in theory a model that is trained to sum embeddings can peform better due to "seeing" previous levles (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so
 	audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings
 	kv_heads: int = 0 # MHA or GQA (for supported backends)
-	p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
+	rvq_levels_p: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely
 	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
@@ -431,8 +432,6 @@ class Evaluation:
 	nar_temperature: float = 0.0 # NAR temp for inferencing
 	nar_levels: int = 0 # maximum NAR levels to use for inferencing
 
-	load_disabled_engines: bool = True # see the other load_disabled_engines
-
 @dataclass()
 class DeepSpeed:
 	zero_optimization_level: int = 0 # doesn't seem to work
@@ -617,10 +616,8 @@ class Trainer:
 	activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
 	gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
 
-	aggressive_optimizations: bool = False # deprecated
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC
-	load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation
 
 	weight_dtype: str = "float16" # dtype to have the model under
 
@@ -628,8 +625,7 @@ class Trainer:
 	ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested
 	#scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload)
 
-	load_webui: bool = False # not working, but loads the web UI to allow inferencing during training
-	no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet
+	load_webui: bool = False # load the web UI to allow inferencing during training, to-do: actually make this work
 
 	backend: str = "local" # training backend to use. currently supports "local" | "deepspeed"
 	deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings
@@ -657,13 +653,7 @@ class Inference:
 	weight_dtype: str = "float32" # dtype to load the model under
 	amp: bool = False # automatic mixed precision during inferencing
 
-	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
-
-	# legacy / backwards compat
-	audio_backend: str = "" # encodec, vocos, dac
-	use_vocos: bool = True
-	use_encodec: bool = True
-	use_dac: bool = True
+	normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though
 
 	@property
 	def dtype(self):
@@ -694,6 +684,7 @@ class Optimizations:
 	bitnet: bool = False # use bitnet
 	fp8: bool = False # use fp8
 
+	# to-do: validate this madness works still, I don't remember what schizodemon told me to do this
 	model_offloading: dict | None = None # automatically splits the model over a list of devices
 	# example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU
 	# example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50%
@@ -705,7 +696,7 @@ class Optimizations:
 class Config(BaseConfig):
 	device: str = "cuda" # target device
 	mode: str = "training" # "inferencing"
-	experimental: bool = False # Debug flag, unused now
+	experimental: bool = False # debug flag
 
 	dataset: Dataset = field(default_factory=lambda: Dataset)
 	models: dict | list | None = field(default_factory=lambda: [])
@@ -714,7 +705,6 @@ class Config(BaseConfig):
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
 	inference: Inference = field(default_factory=lambda: Inference)
-	bitsandbytes: dict | list | None = None # deprecated
 	optimizations: Optimizations = field(default_factory=lambda: Optimizations)
 
 	tokenizer: str | None = None # tokenizer class
@@ -828,7 +818,6 @@ class Config(BaseConfig):
 		return path
 
 
-	# to-do: prune unused keys
 	def format( self, training=True ):
 		if isinstance(self.dataset, type):
 			self.dataset = dict()
@@ -869,19 +858,25 @@ class Config(BaseConfig):
 			if not isinstance( model, dict ):
 				continue
 
-			if "prom_levels" in model:
-				del model["prom_levels"]
-
-			if "interleave" in model:
-				del model["interleave"]
-
-			if "audio_embedding_sums" not in model:
-				continue
-
 			if "experimental" not in model or not model["experimental"]:
 				model["experimental"] = {}
 
-			model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
+			if "prom_levels" in model:
+				_logger.warning(f"Deprecated flag found: {'cfg.model.prom_levels'}")
+				del model["prom_levels"]
+
+			if "interleave" in model:
+				_logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}")
+				del model["interleave"]
+
+			if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]:
+				_logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}")
+				model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
+				del model["experimental"]["p_rvq_levels"]
+
+			if "audio_embedding_sums" in model:
+				_logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}")
+				model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
 
 
 		self.models = [ Model(**model) for model in self.models ]
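Note: as a worked example of the remapping above, a hypothetical legacy model entry (field values invented for illustration) would be rewritten in place roughly like this:

	model = { "name": "ar+nar", "interleave": False, "audio_embedding_sums": True, "experimental": {} }
	# after the block above: "interleave" is deleted with a deprecation warning and
	# "audio_embedding_sums" is moved under "experimental", leaving
	# { "name": "ar+nar", "experimental": { "audio_embedding_sums": True } }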
@@ -905,11 +900,7 @@ class Config(BaseConfig):
 			self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
 
 		self.inference = Inference(**self.inference)
+		self.optimizations = Optimizations(**self.optimizations)
-		if self.bitsandbytes is not None:
-			self.optimizations = Optimizations(**self.bitsandbytes)
-		else:
-			self.optimizations = Optimizations(**self.optimizations)
 
 		if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
 			self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
@@ -922,15 +913,9 @@ class Config(BaseConfig):
 		if self.hyperparameters.scheduler == "":
 			self.hyperparameters.torch_scheduler = True
 
-		if self.dataset.prompt_duration is not None:
-			self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration]
-
 		if self.trainer.backend == "local" and self.distributed:
 			self.trainer.ddp = True
 
-		if self.inference.audio_backend != "" and self.audio_backend == "":
-			self.audio_backend = self.inference.audio_backend
-
 		if self.trainer.activation_checkpointing is not None:
 			self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing
 
@@ -942,22 +927,23 @@ class Config(BaseConfig):
 			self.load_hdf5()
 
 		# load tokenizer
-		if cfg.tokenizer == "naive":
-			cfg.tokenizer = NaiveTokenizer()
+		if self.tokenizer == "naive":
+			self.tokenizer = NaiveTokenizer()
 		else:
+			# ick...
 			try:
 				from transformers import PreTrainedTokenizerFast
 
-				tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
+				tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None
 				if tokenizer_path and not tokenizer_path.exists():
-					tokenizer_path = Path("./data/") / cfg.tokenizer_path
+					tokenizer_path = Path("./data/") / self.tokenizer_path
 
 				if tokenizer_path and tokenizer_path.exists():
-					cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+					self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
 				else:
-					cfg.tokenizer = NaiveTokenizer()
+					self.tokenizer = NaiveTokenizer()
 			except Exception as e:
-				cfg.tokenizer = NaiveTokenizer()
+				self.tokenizer = NaiveTokenizer()
 				_logger.warning(f"Error while parsing tokenizer: {str(e)}")
 				pass
 
@@ -1012,10 +1012,12 @@ class Dataset(_Dataset):
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
 
-		for _ in range(cfg.dataset.max_prompts):
-			if reference is not None and cfg.dataset.prom_sample_similar:
-				path = self.get_similar_utterance( reference, offset = len(prom_list) ) if random.random() < cfg.dataset.prompt_similar_p else random.choice(choices)
+		for _ in range(cfg.dataset.prompt_max_samples):
+			if reference is not None:
 				# yuck
+				path = None
+				if random.random() < cfg.dataset.prompt_similar_p:
+					path = self.get_similar_utterance( reference, offset = len(prom_list) )
 				if not path:
 					path = random.choice(choices)
 			else:
@@ -1032,7 +1034,7 @@ class Dataset(_Dataset):
 			prom_list.append(qnt)
 			prom_length += qnt.shape[0]
 
-			if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
+			if prom_length >= trim_length:
 				break
 
 		# might be better to decode => concat waveforms with silence in between => reencode
@@ -1113,9 +1115,9 @@ class Dataset(_Dataset):
 		naive = cfg.experimental
 
 		# append additional prompts in an attempt to artifically increase lengths / offer new data
-		if cfg.dataset.max_resps > 1 and random.random() < cfg.dataset.p_resp_append:
+		if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
 			ignore_paths = []
-			for _ in range( 1, cfg.dataset.max_resps ):
+			for _ in range( 1, cfg.dataset.resps_max_samples ):
 				path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
 				ignore_paths.append(path)
 
@@ -1316,7 +1318,7 @@ class Dataset(_Dataset):
 		text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
 		# pad the target with silence
-		if random.random() < cfg.dataset.p_resp_pad_silence:
+		if random.random() < cfg.dataset.resps_pad_silence_p:
 			resps = pad_codes_with_silence( resps )
 
 		return dict(
@@ -525,9 +525,6 @@ class Engines(dict[str, Engine]):
 
 			n_ooms = torch.zeros([], device=device)
 
-			if cfg.trainer.aggressive_optimizations:
-				batch = to_device(batch, 'cpu')
-
 			if not cfg.trainer.check_for_oom:
 				engine.backward(loss)
 			else:
@@ -79,7 +79,7 @@ class AR(Base):
 		# is training
 		if training:
 			# specifies how to sample probabilities of which RVQ levels to train against
-			p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
+			rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
 			# determines which RVQ level to target per batch
 			quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
 			# rate to perform token dropout errors
@@ -90,19 +90,19 @@ class AR(Base):
 			if not token_dropout_rvq_levels:
 				token_dropout_rvq_levels = [0, self.resp_levels - 1]
 			# allow passing a specific distribution of RVQ levels
-			p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
-			if not p_rvq_levels:
+			rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
+			if not rvq_levels_p:
 				lo, hi = quant_level_range[0], quant_level_range[1] + 1
 				# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-				if p_rvq_levels == "equal":
-					p_rvq_levels = [ i for i in range( lo, hi ) ]
+				if rvq_levels_p == "equal":
+					rvq_levels_p = [ i for i in range( lo, hi ) ]
 				else:
 					# yuck
-					p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
+					rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
 
 			# input RVQ levels
 			if not self.interleave:
-				quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
+				quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
 				# trim resps to only contain all levels below the target level
 				resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 			else:
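Note: a quick worked example of the fallback distribution built above, run standalone with assumed bounds lo = 0 and hi = 4:

	lo, hi = 0, 4
	rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
	# -> [0, 0, 0, 0, 1, 1, 1, 2, 2, 3], so lower RVQ levels are sampled more often,
	# while the "equal" branch would instead yield [0, 1, 2, 3]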
@@ -102,7 +102,7 @@ class AR_NAR(Base):
 		# is training
 		if training:
 			# specifies how to sample probabilities of which RVQ levels to train against
-			p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
+			rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
 			# determines which RVQ level to target per batch
 			quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
 			# rate to perform token dropout errors
@@ -113,18 +113,18 @@ class AR_NAR(Base):
 			if not token_dropout_rvq_levels:
 				token_dropout_rvq_levels = [0, self.resp_levels - 1]
 			# allow passing a specific distribution of RVQ levels
-			p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
-			if not p_rvq_levels:
+			rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
+			if not rvq_levels_p:
 				lo, hi = quant_level_range[0], quant_level_range[1] + 1
 				# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-				if p_rvq_levels == "equal":
-					p_rvq_levels = [ i for i in range( lo, hi ) ]
+				if rvq_levels_p == "equal":
+					rvq_levels_p = [ i for i in range( lo, hi ) ]
 				else:
 					# yuck
-					p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
+					rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
 
 			# input RVQ levels
-			quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
+			quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
 			for i, task in enumerate( task_list ):
 				if task in text_task:
 					quant_levels[i] = 0 # self.n_resp_levels - 1
@@ -75,7 +75,7 @@ class NAR(Base):
 		task_list = [ sample_task() for _ in range(batch_size) ]
 
 		# specifies how to sample probabilities of which RVQ levels to train against
-		p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal"
+		rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
 		# determines which RVQ level to target per batch
 		quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
 		# rate to perform token dropout errors
@@ -86,18 +86,18 @@ class NAR(Base):
 		if not token_dropout_rvq_levels:
 			token_dropout_rvq_levels = [0, self.resp_levels - 1]
 		# allow passing a specific distribution of RVQ levels
-		p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else []
-		if not p_rvq_levels:
+		rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
+		if not rvq_levels_p:
 			lo, hi = quant_level_range[0], quant_level_range[1] + 1
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			if p_rvq_levels == "equal":
-				p_rvq_levels = [ i for i in range( lo, hi ) ]
+			if rvq_levels_p == "equal":
+				rvq_levels_p = [ i for i in range( lo, hi ) ]
 			else:
 				# yuck
-				p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
+				rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
 
 		# input RVQ levels
-		quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
+		quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
 		# trim resps to only contain all levels below the target level
 		resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
@@ -11,5 +11,6 @@ from .utils import (
 	passes_policy,
 	get_devices,
 	truncate_json,
-	timer
+	timer,
+	prune_missing
 )
@@ -31,6 +31,24 @@ from datetime import datetime
 
 T = TypeVar("T")
 
+def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
+	is_obj = hasattr( source, "__dict__" )
+	if parent_is_obj is None:
+		parent_is_obj = is_obj
+	haystack = source.__dict__ if is_obj else source
+	keep = {}
+	missing = []
+	for k, v in dest.items():
+		if k in haystack or (parent_is_obj and not is_obj and source == {}):
+			keep[k] = dest[k]
+		else:
+			missing.append(".".join(path + [k]))
+
+		if recurse and isinstance( v, dict ):
+			keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing )
+			missing += m
+	return (keep, missing) if return_missing else keep
+
 class timer:
 	def __init__(self, msg="Elapsed time:", callback=None):
 		self.msg = msg
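Note: a minimal sketch of what the new helper does, using plain dicts in place of the default config object (outputs shown in comments):

	reference = { "dataset": { "workers": 8 }, "device": "cuda" }
	incoming  = { "dataset": { "workers": 4, "max_prompts": 3 }, "device": "cuda" }

	kept, missing = prune_missing( source=reference, dest=incoming )
	# kept    == { "dataset": { "workers": 4 }, "device": "cuda" }
	# missing == [ "dataset.max_prompts" ]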