From b3b67f34acb70e86013b5499e6667ced54537b63 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 13 Jun 2024 22:37:34 -0500 Subject: [PATCH] added option to sort paths by durations to better group equally lengthed sequences together (and there was maybe a logic error from creating the samplers and then interleave-reordering paths, desyncing them, maybe) --- vall_e/config.py | 1 + vall_e/data.py | 102 +++++++++++++++++++++++++++++------------------ 2 files changed, 65 insertions(+), 38 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index ee934b5..d639b02 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -149,6 +149,7 @@ class Dataset: p_resp_append: float = 1.0 sample_type: str = "path" # path | speaker + sample_order: str = "shuffle" # shuffle | duration tasks_list: list[str] = field(default_factory=lambda: ["tts"]) diff --git a/vall_e/data.py b/vall_e/data.py index cd16a70..793b299 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -292,13 +292,11 @@ def _get_quant_path(path): def _get_phone_path(path): return _replace_file_extension(path, _get_phone_extension()) -_total_durations = {} - +_durations_map = {} +# makeshift caching the above to disk @cfg.diskcache() -def _calculate_durations( type="training" ): - if type in _total_durations: - return _total_durations[type] - return 0 +def _get_duration_map( type="training" ): + return _durations_map[type] if type in _durations_map else {} @cfg.diskcache() def _load_paths(dataset, type="training"): @@ -324,21 +322,19 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): def _validate( id, entry ): phones = entry['phones'] if "phones" in entry else 0 duration = entry['duration'] if "duration" in entry else 0 - if type not in _total_durations: - _total_durations[type] = 0 - - _total_durations[type] += duration - - """ - if cfg.dataset.use_hdf5: - k = key( id ) - if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]: - return False - """ - return 
cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones + # add to duration bucket + k = key(id, entry) + if type not in _durations_map: + _durations_map[type] = {} + _durations_map[type][k] = duration - return [ key(id, entry) for id, entry in metadata.items() if not validate or _validate(id, entry) ] + if not validate: + return True + + return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration + + return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ] def _get_hdf5_path(path): @@ -348,17 +344,23 @@ def _get_hdf5_path(path): def _get_hdf5_paths( data_dir, type="training", validate=False ): data_dir = str(data_dir) + + key = f"/{type}/{_get_hdf5_path(data_dir)}" def _validate( id, entry ): phones = entry.attrs['phonemes'] duration = entry.attrs['duration'] - if type not in _total_durations: - _total_durations[type] = 0 - _total_durations[type] += entry.attrs['duration'] - return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones - key = f"/{type}/{_get_hdf5_path(data_dir)}" - return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if not validate or _validate(id, entry) ] if key in cfg.hdf5 else [] + if type not in _durations_map: + _durations_map[type] = {} + _durations_map[type][f"{key}/{id}"] = duration + + if not validate: + return True + + return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration + + return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else [] def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ): if isinstance(path, str): @@ -430,6 +432,7 @@ class Dataset(_Dataset): self.dataset_type = "training" if self.training else "validation" self.dataset = 
cfg.dataset.training if self.training else cfg.dataset.validation self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group" + self.sampler_order = cfg.dataset.sample_order # to-do: do not do validation if there's nothing in the validation # this just makes it be happy @@ -446,10 +449,11 @@ class Dataset(_Dataset): if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances: del self.paths_by_spkr_name[key] + # flatten paths self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values())) # split dataset accordingly per GPU - if self.sampler_type == "path" and cfg.distributed and self.training: + if cfg.distributed and self.training: batches = len(self.paths) // world_size() start = batches * global_rank() end = batches * (global_rank() + 1) @@ -464,6 +468,40 @@ class Dataset(_Dataset): self.paths_by_spkr_name[name] = [] self.paths_by_spkr_name[name].append( path ) + # do it here due to the above + self.duration = 0 + self.duration_map = _get_duration_map( self.dataset_type ) + self.duration_buckets = {} + + # store in corresponding bucket + for path in self.paths: + duration = self.duration_map[path] + self.duration += duration + + # only calc duration if we're not going to order by duration + if self.sampler_order != "duration": + continue + + bucket = str(int(round(duration))) + if bucket not in self.duration_buckets: + self.duration_buckets[bucket] = [] + self.duration_buckets[bucket].append( ( Path(path), duration ) ) + + # sort by duration + if self.sampler_order == "duration": + # sort and interleave + for bucket in self.duration_buckets: + # sort by duration + self.duration_buckets[bucket].sort( key=lambda x: x[1] ) + # replace with path + self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ] + # flatten by paths + self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)] + # flatten paths + self.paths = 
list(itertools.chain.from_iterable(self.duration_buckets.values())) + elif self.sampler_order == "shuffle": + # just interleave + self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)] self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() } @@ -485,9 +523,6 @@ class Dataset(_Dataset): self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() } - if self.sampler_type == "path": - self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)] - self.noise_paths = _load_paths(cfg.dataset.noise, "noise") self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values())) @@ -504,15 +539,6 @@ class Dataset(_Dataset): if len(self.paths) == 0: raise ValueError(f"No valid path is found for {self.dataset_type}") - #self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0 - self.duration = _calculate_durations(self.dataset_type) - - """ - @cached_property - def phones(self): - return sorted(set().union(*[_get_phones(path) for path in self.paths])) - """ - def get_speaker(self, path): if isinstance(path, str): path = Path(path)