From 75b90be325f780d8034587635d45f2e761e112ae Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 17 Oct 2024 17:06:48 -0500 Subject: [PATCH] cleaned up unused config flags, allow less strict yaml by pruning missing keys, renamed some dataset configs to be more unified --- vall_e/config.py | 140 ++++++++++++++++++--------------------- vall_e/data.py | 16 +++-- vall_e/engines/base.py | 3 - vall_e/models/ar.py | 14 ++-- vall_e/models/ar_nar.py | 14 ++-- vall_e/models/nar.py | 14 ++-- vall_e/utils/__init__.py | 3 +- vall_e/utils/utils.py | 18 +++++ 8 files changed, 113 insertions(+), 109 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index cfdef7e..48b1a19 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -21,15 +21,7 @@ from functools import cached_property from pathlib import Path from .utils.distributed import world_size - - -def set_seed(seed=None): - if not seed: - seed = time.time() - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) +from .utils import set_seed, prune_missing @dataclass() class BaseConfig: @@ -37,7 +29,7 @@ class BaseConfig: @property def cfg_path(self): - return Path(self.yaml_path.parent) if self.yaml_path is not None else None + return Path(self.yaml_path.parent) if self.yaml_path is not None else Path(__file__).parent.parent / "data" @property def rel_path(self): @@ -95,11 +87,24 @@ class BaseConfig: with open(path, "w") as f: f.write(self.dumps()) + # ick + @classmethod + def prune_missing( cls, yaml ): + default = cls(**{}) + default.format() + #default = json.loads(default.dumps()) + + yaml, missing = prune_missing( source=default, dest=yaml ) + if missing: + _logger.warning(f'Missing keys in YAML: {missing}') + return yaml + @classmethod def from_yaml( cls, yaml_path ): state = {} state = yaml.safe_load(open(yaml_path, "r", encoding="utf-8")) state.setdefault("yaml_path", yaml_path) + state = cls.prune_missing( state ) return cls(**state) @classmethod @@ -130,52 +135,48 @@ class Dataset: validation: list[Path] = field(default_factory=lambda: []) # paths to load into the validation dataset noise: list[Path] = field(default_factory=lambda: []) # paths to load into the noise dataset - temp: list[Path] = field(default_factory=lambda: []) # for when I need to yank things out of a training dataset without cutting it out - + # to-do: replace these since I feel this can be a bottleneck speaker_name_getter: str = "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" # function eval'd to extract a speaker's name from an utternace path speaker_group_getter: str = "lambda p: f'{p.parts[-3]}'" # function eval'd to extract a speaker's group from an utternace path - + # to-do: validate if I can ignore this since this is an artifact from when I only saved phonemes and encoded audio, and no metadata speaker_languages: dict = field(default_factory=lambda: {}) # dict where keys are the language codes and values are the speaker groups - hdf5_name: str = "data.h5" # file name to load the HDF5 dataset use_hdf5: bool = False # whether to load from an HDF5 dataset + hdf5_name: str = "data.h5" # file name to load the HDF5 dataset hdf5_flag: str = "a" # flag to load the HDF5 file, automatically adjusted anyways + use_metadata: bool = False # use genretaed metadata to aid in dataset loading validate: bool = True # validate each utterance on wheter it can be included based on duration range caps workers: int = 8 # number of dataloader workers to spawn cache: bool = True # use diskcache to cache the dataset - phones_range: list[int] = field(default_factory=lambda: 
[4, 256]) # deprecated, the amount of phonemes an utterance can be to be included in the dataset - duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset - prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be - - # to-do: clean up the following block, it's a mess min_utterances: int = 2 # minimum number of utterances a speaker can have - random_utterance: float = 1.0 # probability to use a different utterance rather than using the target utterance as an input prompt - max_prompts: int = 3 # maximum number of utterances that can be included in an input prompt for training - prompt_duration: float | None = None # legacy - max_resps: int = 1 # number of samples to target for training - p_resp_append: float = 1.0 # probability to append another sample to the training target - p_resp_pad_silence: float = 0.0 # probability to pad resp with silence to fit within the next window - prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt - prompt_similar_top_k: int = 1 # top-k similar candidates to sample from - prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from - + duration_range: list[float] = field(default_factory=lambda: [1.0, 12.0]) # the duration range an utterance can be to be included in the dataset + sample_type: str = "path" # path | speaker sample_order: str = "interleaved" # duration + sample_shuffle: bool = True # shuffles the indices in the sampler sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough # for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size) - sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0 - prom_sample_similar: bool = True # if available, try and sample the prompt closest to the sampled response utterance (requires specific metadata generated) + prompt_duration_range: list[float] = field(default_factory=lambda: [3.0, 6.0]) # the duration range the input prompts can be + prompt_max_samples: int = 3 # maximum number of utterances that can be included in an input prompt for training + prompt_continuous_utterance_p: float = 0.0 # probability to use the target utterance as an input prompt rather than using a different utterance + prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt + prompt_similar_top_k: int = 1 # top-k similar candidates to sample from + prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from + + resps_max_samples: int = 1 # number of samples to target for training + resps_append_p: float = 1.0 # probability to append another sample to the training target + resps_pad_silence_p: float = 0.0 # probability to pad resp with silence to fit within the next window tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. 
To use CUDA with multiprocessing, you must use the 'spawn' start method noise_scale: float = 0.25 # scaling noise value - inject_noise_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things + noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things retokenize_text: bool = False _frames_per_second: int = 0 # allows setting your own hint @@ -219,7 +220,7 @@ class ModelExperimentalSettings: # in theory a model that is trained to sum embeddings can peform better due to "seeing" previous levles (due to the R in RVQ standing for residuals...), but in practice it seems fine to not do so audio_embedding_mode: str | None = None # None | "exclusive" | "inclusive", subjugates the audio backend's encoding/decoding model for embeddings kv_heads: int = 0 # MHA or GQA (for supported backends) - p_rvq_levels: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely + rvq_levels_p: str | list = "auto" # determines odds of selecting RVQ levels when training, "equal" will make each level equally likely rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary unified_position_ids: bool = True # False will generate position IDs partitioned for each section tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing @@ -431,8 +432,6 @@ class Evaluation: nar_temperature: float = 0.0 # NAR temp for inferencing nar_levels: int = 0 # maximum NAR levels to use for inferencing - load_disabled_engines: bool = True # see the other load_disabled_engines - @dataclass() class DeepSpeed: zero_optimization_level: int = 0 # doesn't seem to work @@ -617,10 +616,8 @@ class Trainer: activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training - aggressive_optimizations: bool = False # deprecated check_for_oom: bool = True # checks for OOMs thrown during forward/backwards gc_mode: str | None = None # deprecated, but marks when to do GC - load_disabled_engines: bool = False # deprecated, but signals to load engines not used for training for, for example, evaluation/validation weight_dtype: str = "float16" # dtype to have the model under @@ -628,8 +625,7 @@ class Trainer: ddp: bool = False # torch's internal DDP, automatically set if local backend is used and multiple GPUs are requested #scale_loss: bool = False # whether to perform loss scaling (for FP16 training) (it actually seems more harmful than not for this specific workload) - load_webui: bool = False # not working, but loads the web UI to allow inferencing during training - no_logger: bool = False # deprecated, but reroutes some logger calls to normal print statements for when logger broke because of BitNet + load_webui: bool = False # load the web UI to allow inferencing during training, to-do: actually make this work backend: str = "local" # training backend to use. 
currently supports "local" | "deepspeed" deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) # deepspeed settings @@ -657,13 +653,7 @@ class Inference: weight_dtype: str = "float32" # dtype to load the model under amp: bool = False # automatic mixed precision during inferencing - normalize: bool = False # do NOT enable this unless you know exactly what you're doing - - # legacy / backwards compat - audio_backend: str = "" # encodec, vocos, dac - use_vocos: bool = True - use_encodec: bool = True - use_dac: bool = True + normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though @property def dtype(self): @@ -694,6 +684,7 @@ class Optimizations: bitnet: bool = False # use bitnet fp8: bool = False # use fp8 + # to-do: validate this madness works still, I don't remember what schizodemon told me to do this model_offloading: dict | None = None # automatically splits the model over a list of devices # example: {"include":["model"], "limits": [ (6 * 1024) * (1024 ** 2), -1 ]} will have the GPU capped to 6GiB, and offload the remaining layers to CPU # example: {"include":["model"], "device": ["cuda:0", "cuda:1"], "limits": [ 0.5, 0.5 ]} will have the GPU 1 try and use 50% of the model, and GPU 2 try and use the other 50% @@ -705,7 +696,7 @@ class Optimizations: class Config(BaseConfig): device: str = "cuda" # target device mode: str = "training" # "inferencing" - experimental: bool = False # Debug flag, unused now + experimental: bool = False # debug flag dataset: Dataset = field(default_factory=lambda: Dataset) models: dict | list | None = field(default_factory=lambda: []) @@ -714,7 +705,6 @@ class Config(BaseConfig): evaluation: Evaluation = field(default_factory=lambda: Evaluation) trainer: Trainer = field(default_factory=lambda: Trainer) inference: Inference = field(default_factory=lambda: Inference) - bitsandbytes: dict | list | None = None # deprecated optimizations: Optimizations = field(default_factory=lambda: Optimizations) tokenizer: str | None = None # tokenizer class @@ -828,7 +818,6 @@ class Config(BaseConfig): return path - # to-do: prune unused keys def format( self, training=True ): if isinstance(self.dataset, type): self.dataset = dict() @@ -869,19 +858,25 @@ class Config(BaseConfig): if not isinstance( model, dict ): continue - if "prom_levels" in model: - del model["prom_levels"] - - if "interleave" in model: - del model["interleave"] - - if "audio_embedding_sums" not in model: - continue - if "experimental" not in model or not model["experimental"]: model["experimental"] = {} - model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums") + if "prom_levels" in model: + _logger.warning(f"Deprecated flag found: {'cfg.model.prom_levels'}") + del model["prom_levels"] + + if "interleave" in model: + _logger.warning(f"Deprecated flag found: {'cfg.model.interleave'}") + del model["interleave"] + + if "p_rvq_levels" in model["experimental"] and "rvq_levels_p" not in model["experimental"]: + _logger.warning(f"Deprecated flag found: {'cfg.model.experimental.p_rvq_levels'}") + model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"] + del model["experimental"]["p_rvq_levels"] + + if "audio_embedding_sums" in model: + _logger.warning(f"Deprecated flag found: {'cfg.model.p_rvq_levels'}") + model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums") self.models = [ Model(**model) for model in self.models ] @@ -905,11 +900,7 @@ class Config(BaseConfig): 
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed) self.inference = Inference(**self.inference) - - if self.bitsandbytes is not None: - self.optimizations = Optimizations(**self.bitsandbytes) - else: - self.optimizations = Optimizations(**self.optimizations) + self.optimizations = Optimizations(**self.optimizations) if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler: self.hyperparameters.scheduler = self.hyperparameters.scheduler_type @@ -922,15 +913,9 @@ class Config(BaseConfig): if self.hyperparameters.scheduler == "": self.hyperparameters.torch_scheduler = True - if self.dataset.prompt_duration is not None: - self.dataset.prompt_duration_range = [self.dataset.prompt_duration, self.dataset.prompt_duration] - if self.trainer.backend == "local" and self.distributed: self.trainer.ddp = True - if self.inference.audio_backend != "" and self.audio_backend == "": - self.audio_backend = self.inference.audio_backend - if self.trainer.activation_checkpointing is not None: self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing @@ -942,22 +927,23 @@ class Config(BaseConfig): self.load_hdf5() # load tokenizer - if cfg.tokenizer == "naive": - cfg.tokenizer = NaiveTokenizer() + if self.tokenizer == "naive": + self.tokenizer = NaiveTokenizer() else: + # ick... try: from transformers import PreTrainedTokenizerFast - tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None + tokenizer_path = self.rel_path / self.tokenizer_path if self.yaml_path is not None else None if tokenizer_path and not tokenizer_path.exists(): - tokenizer_path = Path("./data/") / cfg.tokenizer_path + tokenizer_path = Path("./data/") / self.tokenizer_path if tokenizer_path and tokenizer_path.exists(): - cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) else: - cfg.tokenizer = NaiveTokenizer() + self.tokenizer = NaiveTokenizer() except Exception as e: - cfg.tokenizer = NaiveTokenizer() + self.tokenizer = NaiveTokenizer() _logger.warning(f"Error while parsing tokenizer: {str(e)}") pass diff --git a/vall_e/data.py b/vall_e/data.py index 67a8c55..86e0ae4 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1012,10 +1012,12 @@ class Dataset(_Dataset): prom_length = 0 trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0 - for _ in range(cfg.dataset.max_prompts): - if reference is not None and cfg.dataset.prom_sample_similar: - path = self.get_similar_utterance( reference, offset = len(prom_list) ) if random.random() < cfg.dataset.prompt_similar_p else random.choice(choices) + for _ in range(cfg.dataset.prompt_max_samples): + if reference is not None: # yuck + path = None + if random.random() < cfg.dataset.prompt_similar_p: + path = self.get_similar_utterance( reference, offset = len(prom_list) ) if not path: path = random.choice(choices) else: @@ -1032,7 +1034,7 @@ class Dataset(_Dataset): prom_list.append(qnt) prom_length += qnt.shape[0] - if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance: + if prom_length >= trim_length: break # might be better to decode => concat waveforms with silence in between => reencode @@ -1113,9 +1115,9 @@ class Dataset(_Dataset): naive = cfg.experimental # append additional prompts in an attempt to artifically increase lengths / offer new data - if cfg.dataset.max_resps > 1 and 
random.random() < cfg.dataset.p_resp_append: + if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p: ignore_paths = [] - for _ in range( 1, cfg.dataset.max_resps ): + for _ in range( 1, cfg.dataset.resps_max_samples ): path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths) ignore_paths.append(path) @@ -1316,7 +1318,7 @@ class Dataset(_Dataset): text = torch.tensor([bos_id, eos_id]).to(self.text_dtype) # pad the target with silence - if random.random() < cfg.dataset.p_resp_pad_silence: + if random.random() < cfg.dataset.resps_pad_silence_p: resps = pad_codes_with_silence( resps ) return dict( diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index 90cc140..b17f47c 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -525,9 +525,6 @@ class Engines(dict[str, Engine]): n_ooms = torch.zeros([], device=device) - if cfg.trainer.aggressive_optimizations: - batch = to_device(batch, 'cpu') - if not cfg.trainer.check_for_oom: engine.backward(loss) else: diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index b5b9dbb..95bd435 100644 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -79,7 +79,7 @@ class AR(Base): # is training if training: # specifies how to sample probabilities of which RVQ levels to train against - p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" + rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" # determines which RVQ level to target per batch quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] # rate to perform token dropout errors @@ -90,19 +90,19 @@ class AR(Base): if not token_dropout_rvq_levels: token_dropout_rvq_levels = [0, self.resp_levels - 1] # allow passing a specific distribution of RVQ levels - p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else [] - if not p_rvq_levels: + rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] + if not rvq_levels_p: lo, hi = quant_level_range[0], quant_level_range[1] + 1 # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - if p_rvq_levels == "equal": - p_rvq_levels = [ i for i in range( lo, hi ) ] + if rvq_levels_p == "equal": + rvq_levels_p = [ i for i in range( lo, hi ) ] else: # yuck - p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) # input RVQ levels if not self.interleave: - quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ] + quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] # trim resps to only contain all levels below the target level resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] else: diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 29e08fc..84c3ee4 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -102,7 +102,7 @@ class AR_NAR(Base): # is training if training: # specifies how to sample probabilities of which RVQ levels to train against - p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" + rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" # determines which RVQ level to target per batch quant_level_range = self.config.experimental.rvq_level_range if self.config is not 
None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] # rate to perform token dropout errors @@ -113,18 +113,18 @@ class AR_NAR(Base): if not token_dropout_rvq_levels: token_dropout_rvq_levels = [0, self.resp_levels - 1] # allow passing a specific distribution of RVQ levels - p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else [] - if not p_rvq_levels: + rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] + if not rvq_levels_p: lo, hi = quant_level_range[0], quant_level_range[1] + 1 # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - if p_rvq_levels == "equal": - p_rvq_levels = [ i for i in range( lo, hi ) ] + if rvq_levels_p == "equal": + rvq_levels_p = [ i for i in range( lo, hi ) ] else: # yuck - p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) # input RVQ levels - quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ] + quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] for i, task in enumerate( task_list ): if task in text_task: quant_levels[i] = 0 # self.n_resp_levels - 1 diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 74cbf76..b0be902 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -75,7 +75,7 @@ class NAR(Base): task_list = [ sample_task() for _ in range(batch_size) ] # specifies how to sample probabilities of which RVQ levels to train against - p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" + rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal" # determines which RVQ level to target per batch quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] # rate to perform token dropout errors @@ -86,18 +86,18 @@ class NAR(Base): if not token_dropout_rvq_levels: token_dropout_rvq_levels = [0, self.resp_levels - 1] # allow passing a specific distribution of RVQ levels - p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else [] - if not p_rvq_levels: + rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else [] + if not rvq_levels_p: lo, hi = quant_level_range[0], quant_level_range[1] + 1 # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - if p_rvq_levels == "equal": - p_rvq_levels = [ i for i in range( lo, hi ) ] + if rvq_levels_p == "equal": + rvq_levels_p = [ i for i in range( lo, hi ) ] else: # yuck - p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) # input RVQ levels - quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ] + quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ] # trim resps to only contain all levels below the target level resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py index c941932..bcc8c44 100755 --- a/vall_e/utils/__init__.py +++ b/vall_e/utils/__init__.py @@ -11,5 +11,6 @@ from .utils import ( passes_policy, get_devices, truncate_json, - timer + timer, + prune_missing ) \ No newline at end of file diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index b487c46..08d6912 100755 
--- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -31,6 +31,24 @@ from datetime import datetime T = TypeVar("T") +def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ): + is_obj = hasattr( source, "__dict__" ) + if parent_is_obj is None: + parent_is_obj = is_obj + haystack = source.__dict__ if is_obj else source + keep = {} + missing = [] + for k, v in dest.items(): + if k in haystack or (parent_is_obj and not is_obj and source == {}): + keep[k] = dest[k] + else: + missing.append(".".join(path + [k])) + + if recurse and isinstance( v, dict ): + keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing ) + missing += m + return (keep, missing) if return_missing else keep + class timer: def __init__(self, msg="Elapsed time:", callback=None): self.msg = msg
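The headline change is the new prune_missing() helper in vall_e/utils/utils.py plus the BaseConfig.prune_missing classmethod that calls it from from_yaml(), so stale keys in an old YAML no longer blow up dataclass construction with an unexpected-keyword TypeError; they are dropped and logged as a warning instead. Below is a minimal standalone sketch of that behaviour: the function body is copied from the hunk above, while FakeTrainer and the yaml_state dict are hypothetical stand-ins (load_disabled_engines is one of the flags this commit removes, so it makes a convenient stale key).

def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True ):
	# keep only the keys of `dest` that still exist on `source` (an object or a dict),
	# collecting the dotted paths of everything that got dropped
	is_obj = hasattr( source, "__dict__" )
	if parent_is_obj is None:
		parent_is_obj = is_obj
	haystack = source.__dict__ if is_obj else source
	keep = {}
	missing = []
	for k, v in dest.items():
		if k in haystack or (parent_is_obj and not is_obj and source == {}):
			keep[k] = dest[k]
		else:
			missing.append(".".join(path + [k]))

		if recurse and isinstance( v, dict ):
			keep[k], m = prune_missing( haystack[k], dest[k], path=path + [k], parent_is_obj=parent_is_obj, return_missing=return_missing )
			missing += m
	return (keep, missing) if return_missing else keep

class FakeTrainer:
	# hypothetical stand-in for the Trainer dataclass, only to illustrate the pruning
	def __init__(self):
		self.check_for_oom = True
		self.gradient_checkpointing = True

yaml_state = {
	"check_for_oom": False,
	"load_disabled_engines": True,	# removed by this commit, so it should be pruned
}

kept, missing = prune_missing( source=FakeTrainer(), dest=yaml_state )
print( kept )		# {'check_for_oom': False}
print( missing )	# ['load_disabled_engines']

One caveat worth keeping an eye on: a pruned key whose YAML value is itself a dict still reaches haystack[k] in the recursion branch, so a removed nested section would raise a KeyError rather than being dropped quietly.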
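Config.format() now also rewrites deprecated model keys with a warning instead of silently deleting them: prom_levels and interleave are dropped, audio_embedding_sums is relocated under experimental, and experimental.p_rvq_levels is renamed to experimental.rvq_levels_p. A rough before/after of a single model entry (the field values here are made up for illustration) looks like the following; note that the warning emitted in the audio_embedding_sums branch still reuses the 'cfg.model.p_rvq_levels' string, which looks like a copy-paste slip worth fixing in a follow-up.

# what an old YAML model entry might carry
old_model = {
	"name": "ar+nar",
	"prom_levels": 8,				# deprecated, deleted with a warning
	"interleave": False,				# deprecated, deleted with a warning
	"audio_embedding_sums": True,			# moved under "experimental"
	"experimental": { "p_rvq_levels": "auto" },	# renamed to "rvq_levels_p"
}

# roughly what format() leaves behind before Model(**model) is constructed
new_model = {
	"name": "ar+nar",
	"experimental": {
		"rvq_levels_p": "auto",
		"audio_embedding_sums": True,
	},
}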
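The p_rvq_levels → rvq_levels_p rename touches ar.py, ar_nar.py and nar.py, but the sampling logic itself is unchanged: when the setting is a string rather than an explicit list, the training path builds a pool in which lower RVQ levels are duplicated more heavily and then draws one level per batch item. The config comment says "equal" makes each level equally likely, though as written the string is coerced to an empty list before the == "equal" comparison, so both "auto" and "equal" currently fall through to the weighted pool. A small sketch of what that pool looks like, assuming an example 8-level codec (lo=0, hi=8 are illustrative, mirroring quant_level_range):

import random
from collections import Counter

lo, hi = 0, 8	# example quant_level_range for an 8-level codec

# default weighting, expression taken verbatim from ar_nar.py:
# level 0 (the AR level) appears 8 times, the last NAR level once
pool = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
print( Counter(pool) )	# Counter({0: 8, 1: 7, 2: 6, 3: 5, 4: 4, 5: 3, 6: 2, 7: 1})

# the intended "equal" weighting: one entry per level
equal_pool = [ i for i in range( lo, hi ) ]

# one target RVQ level per batch item, as in the training path
quant_levels = [ random.choice( pool ) for _ in range(4) ]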