added option to sort paths by duration to better group equal-length sequences together (this may also fix a logic error where the samplers were created first and the paths interleave-reordered afterwards, desyncing the two)
parent 83eab4fa59
commit b3b67f34ac
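The diff reads the new knob as cfg.dataset.sample_order (see the Dataset config hunk below). A minimal sketch of opting in from code, assuming the package-level cfg import the repo already exposes; "shuffle" remains the default:

    from vall_e.config import cfg  # assumed import path; in-repo modules use `from .config import cfg`

    # hypothetical opt-in, mirroring the existing cfg.dataset.sample_type knob;
    # a YAML config would set the same dataset.sample_order key
    cfg.dataset.sample_order = "duration"  # bucket + sort paths so equal-length utterances group together
    # cfg.dataset.sample_order = "shuffle" # default: just interleave paths by speaker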
vall_e/config.py (filename inferred; this hunk patches the Dataset config dataclass):

@@ -149,6 +149,7 @@ class Dataset:
     p_resp_append: float = 1.0
 
     sample_type: str = "path" # path | speaker
+    sample_order: str = "shuffle" # duration
 
     tasks_list: list[str] = field(default_factory=lambda: ["tts"])
vall_e/data.py | 102
@@ -292,13 +292,11 @@ def _get_quant_path(path):
 def _get_phone_path(path):
     return _replace_file_extension(path, _get_phone_extension())
 
 _total_durations = {}
-# makeshift caching the above to disk
-@cfg.diskcache()
-def _calculate_durations( type="training" ):
-    if type in _total_durations:
-        return _total_durations[type]
-    return 0
+_durations_map = {}
+
+def _get_duration_map( type="training" ):
+    return _durations_map[type] if type in _durations_map else {}
 
 @cfg.diskcache()
 def _load_paths(dataset, type="training"):
@@ -324,21 +322,19 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
     def _validate( id, entry ):
         phones = entry['phones'] if "phones" in entry else 0
         duration = entry['duration'] if "duration" in entry else 0
         if type not in _total_durations:
             _total_durations[type] = 0
         _total_durations[type] += duration
 
-        """
-        if cfg.dataset.use_hdf5:
-            k = key( id )
-            if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
-                return False
-        """
-
-        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
-
-    return [ key(id, entry) for id, entry in metadata.items() if not validate or _validate(id, entry) ]
+        # add to duration bucket
+        k = key(id, entry)
+        if type not in _durations_map:
+            _durations_map[type] = {}
+        _durations_map[type][k] = duration
+
+        if not validate:
+            return True
+
+        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+
+    return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
 
 def _get_hdf5_path(path):
@@ -348,17 +344,23 @@ def _get_hdf5_path(path):
 
 def _get_hdf5_paths( data_dir, type="training", validate=False ):
     data_dir = str(data_dir)
 
+    key = f"/{type}/{_get_hdf5_path(data_dir)}"
+
     def _validate( id, entry ):
         phones = entry.attrs['phonemes']
         duration = entry.attrs['duration']
         if type not in _total_durations:
             _total_durations[type] = 0
         _total_durations[type] += entry.attrs['duration']
-        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
 
-    key = f"/{type}/{_get_hdf5_path(data_dir)}"
-    return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if not validate or _validate(id, entry) ] if key in cfg.hdf5 else []
+        if type not in _durations_map:
+            _durations_map[type] = {}
+        _durations_map[type][f"{key}/{id}"] = duration
+
+        if not validate:
+            return True
+
+        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+
+    return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
 
 def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
     if isinstance(path, str):
@@ -430,6 +432,7 @@ class Dataset(_Dataset):
         self.dataset_type = "training" if self.training else "validation"
         self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
         self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
+        self.sampler_order = cfg.dataset.sample_order
 
         # to-do: do not do validation if there's nothing in the validation
         # this just makes it be happy
@@ -446,10 +449,11 @@ class Dataset(_Dataset):
             if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
                 del self.paths_by_spkr_name[key]
 
+        # flatten paths
         self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
 
         # split dataset accordingly per GPU
-        if cfg.distributed and self.training:
+        if self.sampler_type == "path" and cfg.distributed and self.training:
             batches = len(self.paths) // world_size()
             start = batches * global_rank()
             end = batches * (global_rank() + 1)
|
|||
self.paths_by_spkr_name[name] = []
|
||||
self.paths_by_spkr_name[name].append( path )
|
||||
|
||||
# do it here due to the above
|
||||
self.duration = 0
|
||||
self.duration_map = _get_duration_map( self.dataset_type )
|
||||
self.duration_buckets = {}
|
||||
|
||||
# store in corresponding bucket
|
||||
for path in self.paths:
|
||||
duration = self.duration_map[path]
|
||||
self.duration += duration
|
||||
|
||||
# only calc duration if we're tot going to order by duration
|
||||
if self.sampler_order != "duration":
|
||||
continue
|
||||
|
||||
bucket = str(int(round(duration)))
|
||||
if bucket not in self.duration_buckets:
|
||||
self.duration_buckets[bucket] = []
|
||||
self.duration_buckets[bucket].append( ( Path(path), duration ) )
|
||||
|
||||
# sort by duration
|
||||
if self.sampler_order == "duration":
|
||||
# sort and interleave
|
||||
for bucket in self.duration_buckets:
|
||||
# sort by duration
|
||||
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
|
||||
# replace with path
|
||||
self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
|
||||
# flatten by paths
|
||||
self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
|
||||
# flatten paths
|
||||
self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
|
||||
elif self.sampler_order == "shuffle":
|
||||
# just interleave
|
||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||
|
||||
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
||||
|
||||
|
@@ -485,9 +523,6 @@ class Dataset(_Dataset):
 
         self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
 
-        if self.sampler_type == "path":
-            self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
-
         self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
         self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
@@ -504,15 +539,6 @@ class Dataset(_Dataset):
         if len(self.paths) == 0:
             raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-        #self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
-        self.duration = _calculate_durations(self.dataset_type)
-
-        """
-        @cached_property
-        def phones(self):
-            return sorted(set().union(*[_get_phones(path) for path in self.paths]))
-        """
-
     def get_speaker(self, path):
         if isinstance(path, str):
             path = Path(path)
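A minimal standalone sketch of what sample_order == "duration" does to the path list, using toy paths, speakers, and durations (all hypothetical); interleaved_reorder below is a simplified stand-in for the repo's _interleaved_reorder, assumed to round-robin paths across speakers:

    from itertools import chain

    def interleaved_reorder(paths, get_speaker):
        # simplified stand-in for _interleaved_reorder: round-robin across
        # speakers so no single speaker dominates a contiguous stretch
        groups = {}
        for p in paths:
            groups.setdefault(get_speaker(p), []).append(p)
        out = []
        while any(groups.values()):
            for spkr in list(groups):
                if groups[spkr]:
                    out.append(groups[spkr].pop(0))
        return out

    # toy metadata: path -> (speaker, duration in seconds)
    meta = {
        "a/1": ("alice", 1.2), "a/2": ("alice", 3.9), "a/3": ("alice", 1.1),
        "b/1": ("bob",   1.3), "b/2": ("bob",   4.1), "b/3": ("bob",   3.8),
    }
    get_speaker = lambda p: meta[p][0]

    # bucket by rounded duration, as the diff does with str(int(round(duration)))
    buckets = {}
    for path, (_, duration) in meta.items():
        buckets.setdefault(str(int(round(duration))), []).append((path, duration))

    # per bucket: sort by exact duration, drop the duration, interleave by speaker
    for k in buckets:
        buckets[k].sort(key=lambda x: x[1])
        buckets[k] = interleaved_reorder([p for p, _ in buckets[k]], get_speaker)

    # flatten the buckets back into one path list
    paths = list(chain.from_iterable(buckets.values()))
    print(paths)  # ['a/3', 'b/1', 'a/1', 'b/3', 'a/2', 'b/2']

Near-equal-length utterances end up adjacent (the ~1s bucket before the ~4s bucket) while speakers stay interleaved, and the per-speaker Sampler objects are now built only after this final ordering, rather than before the interleave-reorder as on the old code path.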