From f7b8b1e825507a389a15f86d066971d4d4d7cc35 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 11 Nov 2024 17:00:49 -0600
Subject: [PATCH] dropped subtrain dataloader since it's useless to duplicate

---
 docs/data.md            |  9 ++++++--
 docs/emb.md             |  2 +-
 vall_e/data.py          | 47 ++++++++++++++++++-----------------
 vall_e/train.py         | 14 +++++++++---
 vall_e/utils/sampler.py | 34 +++++++++++++++++++++--------
 5 files changed, 65 insertions(+), 41 deletions(-)

diff --git a/docs/data.md b/docs/data.md
index 466018e..c9474c8 100644
--- a/docs/data.md
+++ b/docs/data.md
@@ -61,7 +61,12 @@ The dataloader handles some simple yet effective features, such as:
 * picking an input prompt from the same speaker as the sample, if the above is not requested
 * preparing the input sequence for the given task (such as non-TTS tasks)
 
-The initial list of paths is cached through `diskcache`, if `cfg.dataset.cache == True`. Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
+If `cfg.dataset.cache == True`, the initial list of paths and duration metadata (used for sorting/bucketing) is cached under `{YAML_PATH}/.cache/{DATASET_HASH}/`. To allow for seamless modifications to the loaded dataset, the `DATASET_HASH` relies on:
+* duration range
+* folders/groups in the dataset
+* if using HDF5 (due to the key format differing)
+
+Be sure to delete the resultant `.cache` folder, as well as the `sampler.*` state dicts alongside checkpoints, if you plan to modify the dataloader settings between training sessions.
 
 ## Tasks
 
@@ -95,7 +100,7 @@ This section may be covered elsewhere in the documentation, but coverage here sh
 * `nse`: noisy speech editing.
   * the above, but injects some noise throughout the sampled utterances.
 
-A mystical `rvc` for emulating RVC speech-to-speech synthesis is possible, but requires a dataset to do so.
+A mystical `vc` for performing voice conversion is possible, but it either requires a dataset to do so or relies on abusing an emergent property.
 
 ## `__main__`
 
diff --git a/docs/emb.md b/docs/emb.md
index eda4e9a..a27a0ce 100644
--- a/docs/emb.md
+++ b/docs/emb.md
@@ -60,7 +60,7 @@ Descript-Audio-Codec was thoroughly tested for promising much, much cleaner outp
 
 However, due to the nature of the codec, simply throwing it at an attention-based transformer proves to be painful, as a unified AR+NAR model *heavily* suffers from noisy output in the NAR.
 
-Ironically, testing through mal-encoded audio (feeding 24KHz audio without upsampling to 44.1KHz) proved to have "cleaner" but bad utterances.
+Ironically, testing with erroneously encoded audio (feeding 24KHz audio without upsampling it to 44.1KHz) produced "cleaner" but bad utterances.
I'm uncertain on how to remedy this, as my options are: * train under a RetNet, if an attention-based transformer is simply the problem diff --git a/vall_e/data.py b/vall_e/data.py index eb501f5..b9ded90 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -534,8 +534,11 @@ _durations_map = {} def _get_duration_map( type="training" ): return _durations_map[type] if type in _durations_map else {} -def _load_paths(dataset, type="training", silent=False): - cached_dir = cfg.cache_dir / cfg.dataset.hash_key(sorted(dataset)) +def _load_paths(dataset, type="training", silent=False, dataset_hash_key=None): + if not dataset_hash_key: + dataset_hash_key = cfg.dataset.hash_key(sorted(dataset)) + + cached_dir = cfg.cache_dir / dataset_hash_key cached_durations_path = cached_dir / f"durations[{type}].json" cached_paths_path = cached_dir / f"dataloader[{type}].json" @@ -691,11 +694,13 @@ class Dataset(_Dataset): self.training = training self.dataset_type = "training" if self.training else "validation" - self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation + self.dataset = sorted(cfg.dataset.training if self.training else cfg.dataset.validation) self.sampler_type = cfg.dataset.sample_type if self.dataset_type == "training" else "path" self.sampler_order = cfg.dataset.sample_order self.sampler_shuffle = cfg.dataset.sample_shuffle if self.dataset_type == "training" else True + self.dataset_hash_key = cfg.dataset.hash_key(sorted(self.dataset)) + # to-do: do not do validation if there's nothing in the validation # this just makes it be happy if len(self.dataset) == 0: @@ -706,7 +711,7 @@ class Dataset(_Dataset): raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.') # dict of paths keyed by speaker names - self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type) + self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key) # do it here due to the above self.duration = 0 self.duration_map = _get_duration_map( self.dataset_type ) @@ -779,10 +784,6 @@ class Dataset(_Dataset): else: # just interleave self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)] - - # dereference buckets - self.duration_map = None - self.duration_buckets = None # dict of speakers keyed by speaker group self.spkrs_by_spkr_group = {} @@ -830,16 +831,21 @@ class Dataset(_Dataset): self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways max_duration=cfg.dataset.sample_max_duration_batch, max_batch_size=cfg.hyperparameters.batch_size if self.training else cfg.evaluation.batch_size, - shuffle=self.sampler_shuffle + shuffle=self.sampler_shuffle, + dataset_hash=self.dataset_hash_key, ) else: - self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) ) + self.sampler = OrderedSampler( len(self), dataset_hash=self.dataset_hash_key ) if not self.sampler_shuffle else RandomSampler( len(self), dataset_hash=self.dataset_hash_key ) self.samplers = {} self.spkr_samplers = {} else: - self.sampler = RandomSampler( len(self) ) - self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() } - self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() } + 
self.sampler = RandomSampler( len(self), dataset_hash=self.dataset_hash_key )
+			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, paths in self.paths_by_spkr_name.items() }
+			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle, dataset_hash=self.dataset_hash_key ) for name, speakers in self.spkrs_by_spkr_group.items() }
+
+		# dereference buckets
+		self.duration_map = None
+		self.duration_buckets = None
 
 		self.load_state_dict()
 
@@ -903,6 +909,7 @@ class Dataset(_Dataset):
 			return
 
 		state_dict = torch_load(path)
+
 		if self.sampler_type == "path":
 			state_dict = self.sampler.set_state(state_dict)
 		else:
@@ -1461,16 +1468,8 @@ def create_val_dataloader():
 # to-do, use the above two, then create the subtrain dataset
 def create_train_val_dataloader():
 	train_dataset, val_dataset = create_datasets()
-
-	# deepcopy is slow
-	subtrain_dataset = Dataset( training=True )
-
-	if subtrain_dataset.sampler_type == "path":
-		subtrain_dataset.head_(cfg.evaluation.size)
-
 	train_dl = _create_dataloader(train_dataset, training=True)
 	val_dl = _create_dataloader(val_dataset, training=False)
-	subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
 
 	_logger.info(str(train_dataset.phone_symmap))
 	_logger.info(str(train_dataset.spkr_symmap))
@@ -1478,18 +1477,14 @@
 
 	_logger.info(f"#samples (train): {len(train_dataset)}.")
 	_logger.info(f"#samples (val): {len(val_dataset)}.")
-	_logger.info(f"#samples (subtrain): {len(subtrain_dataset)}.")
 
 	_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
 	_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
-	_logger.info(f"#duration (subtrain): {str(subtrain_dataset.duration)}.")
-
-	assert isinstance(subtrain_dl.dataset, Dataset)
 
 	# remove duration map (it gets bloated)
 	_durations_map = {}
 
-	return train_dl, subtrain_dl, val_dl
+	return train_dl, val_dl
 
 # parse metadata from an numpy file (.enc/.dac) and validate it
 def process_artifact_metadata( artifact ):
diff --git a/vall_e/train.py b/vall_e/train.py
index 69b9307..2082dac 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -108,7 +108,15 @@ def run_eval(engines, eval_name, dl, args=None):
 
 	processed = 0
 	while processed < cfg.evaluation.size:
-		batch = to_device(next(iter(dl)), cfg.device)
+		# directly randomly sample
+		if eval_name == "subtrain":
+			# sample from dataset
+			# to-do: derive from current iteration
+			samples = [ to_device(dl.dataset[random.randint( 0, len( dl.dataset ) - 1 )], cfg.device) for sample in range( cfg.evaluation.batch_size ) ]
+			# collate manually
+			batch = {k: [s[k] for s in samples] for k in samples[0]}
+		else:
+			batch = to_device(next(iter(dl)), cfg.device)
 
 		# limit to eval batch size in the event we somehow have a weird dataloader
 		for key in batch.keys():
@@ -209,14 +217,14 @@ def train():
 	if cfg.yaml_path is not None and is_global_leader():
 		shutil.copy( cfg.yaml_path, cfg.log_dir / "config.yaml" )
 	# create dataloaders
-	train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+	train_dl, val_dl = create_train_val_dataloader()
 	# evaluation lambda
 	def eval_fn(engines):
 		do_gc()
 		engines.eval()
 		# wrapped in a try block because it's sometimes prone to breaking
 		try:
-			run_eval(engines, "subtrain", subtrain_dl, args)
+			run_eval(engines, "subtrain", train_dl, args)
 			run_eval(engines, "val", val_dl, args)
 		except Exception as e:
 			_logger.warning(f"Error occurred while performing eval: {str(e)}")
diff --git 
a/vall_e/utils/sampler.py b/vall_e/utils/sampler.py index d5b6b76..2e072ec 100644 --- a/vall_e/utils/sampler.py +++ b/vall_e/utils/sampler.py @@ -9,11 +9,12 @@ from .distributed import global_rank, local_rank, world_size # Randomly picks an index from an array of indices class PoolSampler(): - def __init__( self, pool = [], keep_all = False, shuffle = False ): + def __init__( self, pool = [], keep_all = False, shuffle = False, dataset_hash = None ): self.length = len(pool) self.shuffle = shuffle self.global_pool = pool if keep_all else None self.global_indices = [ i for i in range(self.length) ] + self.dataset_hash = dataset_hash self.reset() def reset(self): @@ -45,21 +46,25 @@ class PoolSampler(): return self.sample(*args, **kwargs) def get_state(self): - return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool } + return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool, "dataset_hash": self.dataset_hash } def set_state(self, state): self.length = state["length"] self.global_pool = state["global_pool"] self.global_indices = state["global_indices"] self.current_pool = state["current_pool"] + # could .pop() + if "dataset_hash" in state: + self.dataset_hash = state["dataset_hash"] # "Samples" through a fixed sequence from 0 to length # Necessary for our "shuffle+sort by duration+interleave" sampling method # Allows saving and loading state class OrderedSampler(Sampler): - def __init__( self, length ): + def __init__( self, length, dataset_hash=None ): self.position = 0 self.length = length + self.dataset_hash = dataset_hash def __len__(self): return self.length @@ -73,18 +78,22 @@ class OrderedSampler(Sampler): self.position += 1 def get_state(self): - return { "position": self.position, "length": self.length } + return { "position": self.position, "length": self.length, "dataset_hash": self.dataset_hash } def set_state(self, state): self.position = state["position"] self.length = state["length"] + # could .pop() + if "dataset_hash" in state: + self.dataset_hash = state["dataset_hash"] # Like the above, but will batch based on token count class BatchedOrderedSampler(Sampler): - def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False ): + def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, dataset_hash=None ): self.position = 0 self.batches = [] self.shuffle = shuffle + self.dataset_hash = dataset_hash assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0" @@ -126,18 +135,22 @@ class BatchedOrderedSampler(Sampler): self.position += 1 def get_state(self): - return { "position": self.position, "batches": self.batches } + return { "position": self.position, "batches": self.batches, "dataset_hash": self.dataset_hash } def set_state(self, state): self.position = state["position"] self.batches = state["batches"] + # could .pop() + if "dataset_hash" in state: + self.dataset_hash = state["dataset_hash"] # Randomly samples indices from a given sequence from 0 to length # Allows saving and loading state class RandomSampler(Sampler): - def __init__( self, length ): + def __init__( self, length, dataset_hash=None ): self.position = 0 self.length = length + self.dataset_hash = dataset_hash self.generator = torch.Generator() self.perm = torch.randperm(self.length, generator=self.generator) @@ -155,10 +168,13 @@ class RandomSampler(Sampler): self.position 
+= 1 def get_state(self): - return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() } + return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state(), "dataset_hash": self.dataset_hash } def set_state(self, state): self.position = state["position"] self.length = state["length"] self.perm = state["perm"] - self.generator.set_state(state["generator"]) \ No newline at end of file + self.generator.set_state(state["generator"]) + # could .pop() + if "dataset_hash" in state: + self.dataset_hash = state["dataset_hash"] \ No newline at end of file
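
Not part of the patch above: a minimal sketch of what the `dataset_hash` now threaded through the samplers is for. The patch itself only derives the hash (via `cfg.dataset.hash_key(sorted(dataset))`) and stores it in each sampler's state dict; a check like the one below would be one way to use it so that a stale `sampler.*` state is discarded rather than silently reused after the dataset changes. The names `derive_dataset_hash` and `restore_sampler_state`, and the exact fields hashed, are illustrative assumptions, not the repository's API.

```python
# Illustrative sketch only -- not the repository's actual implementation.
# (1) derive a hash from the properties docs/data.md says should invalidate the cache,
# (2) refuse to restore a sampler state that was saved under a different hash.
import hashlib
import json

def derive_dataset_hash(folders: list[str], duration_range: tuple[float, float], use_hdf5: bool) -> str:
    # Only the properties that should invalidate cached paths / sampler state go into the key.
    key = json.dumps({
        "folders": sorted(folders),
        "duration_range": list(duration_range),
        "use_hdf5": use_hdf5,
    }, sort_keys=True)
    return hashlib.sha256(key.encode("utf-8")).hexdigest()

def restore_sampler_state(sampler, state: dict, current_hash: str) -> bool:
    saved_hash = state.get("dataset_hash")
    # Older states carry no hash, so only a definite mismatch is rejected.
    if saved_hash is not None and saved_hash != current_hash:
        return False  # dataset changed since the state was saved; start sampling fresh
    sampler.set_state(state)
    return True
```

Used this way, a checkpointed sampling order would fall back to a fresh one whenever the duration range, the dataset folders/groups, or the HDF5 setting changes, which is the situation docs/data.md warns about when it says to delete the `.cache` folder and the `sampler.*` state dicts.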