added option to sort paths by durations to better group equally lengthed sequences together (and there was maybe a logic error from creating the samplers and then interleave-reordering paths, desyncing them, maybe)

This commit is contained in:
mrq 2024-06-13 22:37:34 -05:00
parent 83eab4fa59
commit b3b67f34ac
2 changed files with 65 additions and 38 deletions

View File

@ -149,6 +149,7 @@ class Dataset:
p_resp_append: float = 1.0
sample_type: str = "path" # path | speaker
sample_order: str = "shuffle" # duration
tasks_list: list[str] = field(default_factory=lambda: ["tts"])

View File

@ -292,13 +292,11 @@ def _get_quant_path(path):
def _get_phone_path(path):
return _replace_file_extension(path, _get_phone_extension())
_total_durations = {}
_durations_map = {}
# makeshift caching the above to disk
@cfg.diskcache()
def _calculate_durations( type="training" ):
if type in _total_durations:
return _total_durations[type]
return 0
def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
@cfg.diskcache()
def _load_paths(dataset, type="training"):
@ -324,21 +322,19 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
def _validate( id, entry ):
phones = entry['phones'] if "phones" in entry else 0
duration = entry['duration'] if "duration" in entry else 0
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += duration
"""
if cfg.dataset.use_hdf5:
k = key( id )
if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
return False
"""
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
# add to duration bucket
k = key(id, entry)
if type not in _durations_map:
_durations_map[type] = {}
_durations_map[type][k] = duration
return [ key(id, entry) for id, entry in metadata.items() if not validate or _validate(id, entry) ]
if not validate:
return True
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
def _get_hdf5_path(path):
@ -348,17 +344,23 @@ def _get_hdf5_path(path):
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
key = f"/{type}/{_get_hdf5_path(data_dir)}"
def _validate( id, entry ):
phones = entry.attrs['phonemes']
duration = entry.attrs['duration']
if type not in _total_durations:
_total_durations[type] = 0
_total_durations[type] += entry.attrs['duration']
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
key = f"/{type}/{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if not validate or _validate(id, entry) ] if key in cfg.hdf5 else []
if type not in _durations_map:
_durations_map[type] = {}
_durations_map[type][f"{key}/{id}"] = duration
if not validate:
return True
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
if isinstance(path, str):
@ -430,6 +432,7 @@ class Dataset(_Dataset):
self.dataset_type = "training" if self.training else "validation"
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
self.sampler_order = cfg.dataset.sample_order
# to-do: do not do validation if there's nothing in the validation
# this just makes it be happy
@ -446,10 +449,11 @@ class Dataset(_Dataset):
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
del self.paths_by_spkr_name[key]
# flatten paths
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
# split dataset accordingly per GPU
if self.sampler_type == "path" and cfg.distributed and self.training:
if cfg.distributed and self.training:
batches = len(self.paths) // world_size()
start = batches * global_rank()
end = batches * (global_rank() + 1)
@ -464,6 +468,40 @@ class Dataset(_Dataset):
self.paths_by_spkr_name[name] = []
self.paths_by_spkr_name[name].append( path )
# do it here due to the above
self.duration = 0
self.duration_map = _get_duration_map( self.dataset_type )
self.duration_buckets = {}
# store in corresponding bucket
for path in self.paths:
duration = self.duration_map[path]
self.duration += duration
# only calc duration if we're tot going to order by duration
if self.sampler_order != "duration":
continue
bucket = str(int(round(duration)))
if bucket not in self.duration_buckets:
self.duration_buckets[bucket] = []
self.duration_buckets[bucket].append( ( Path(path), duration ) )
# sort by duration
if self.sampler_order == "duration":
# sort and interleave
for bucket in self.duration_buckets:
# sort by duration
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
# replace with path
self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
# flatten by paths
self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
# flatten paths
self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
elif self.sampler_order == "shuffle":
# just interleave
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
@ -485,9 +523,6 @@ class Dataset(_Dataset):
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
if self.sampler_type == "path":
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
@ -504,15 +539,6 @@ class Dataset(_Dataset):
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self.duration = _calculate_durations(self.dataset_type)
"""
@cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
"""
def get_speaker(self, path):
if isinstance(path, str):
path = Path(path)