added option to sort paths by duration to better group equal-length sequences together (this may also fix a possible logic error where creating the samplers and then interleave-reordering paths could desync them)
This commit is contained in:
parent
83eab4fa59
commit
b3b67f34ac
|
@ -149,6 +149,7 @@ class Dataset:
|
||||||
p_resp_append: float = 1.0
|
p_resp_append: float = 1.0
|
||||||
|
|
||||||
sample_type: str = "path" # path | speaker
|
sample_type: str = "path" # path | speaker
|
||||||
|
sample_order: str = "shuffle" # duration
|
||||||
|
|
||||||
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
tasks_list: list[str] = field(default_factory=lambda: ["tts"])
|
||||||
|
|
||||||
|
|
102
vall_e/data.py
102
vall_e/data.py
|
@ -292,13 +292,11 @@ def _get_quant_path(path):
|
||||||
def _get_phone_path(path):
|
def _get_phone_path(path):
|
||||||
return _replace_file_extension(path, _get_phone_extension())
|
return _replace_file_extension(path, _get_phone_extension())
|
||||||
|
|
||||||
_total_durations = {}
|
_durations_map = {}
|
||||||
|
# makeshift caching the above to disk
|
||||||
@cfg.diskcache()
|
@cfg.diskcache()
|
||||||
def _calculate_durations( type="training" ):
|
def _get_duration_map( type="training" ):
|
||||||
if type in _total_durations:
|
return _durations_map[type] if type in _durations_map else {}
|
||||||
return _total_durations[type]
|
|
||||||
return 0
|
|
||||||
|
|
||||||
@cfg.diskcache()
|
@cfg.diskcache()
|
||||||
def _load_paths(dataset, type="training"):
|
def _load_paths(dataset, type="training"):
|
||||||
|
@ -324,21 +322,19 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||||
def _validate( id, entry ):
|
def _validate( id, entry ):
|
||||||
phones = entry['phones'] if "phones" in entry else 0
|
phones = entry['phones'] if "phones" in entry else 0
|
||||||
duration = entry['duration'] if "duration" in entry else 0
|
duration = entry['duration'] if "duration" in entry else 0
|
||||||
if type not in _total_durations:
|
|
||||||
_total_durations[type] = 0
|
|
||||||
|
|
||||||
_total_durations[type] += duration
|
|
||||||
|
|
||||||
"""
|
|
||||||
if cfg.dataset.use_hdf5:
|
|
||||||
k = key( id )
|
|
||||||
if k not in cfg.hdf5 or "audio" not in cfg.hdf5[k] or "text" not in cfg.hdf5[k]:
|
|
||||||
return False
|
|
||||||
"""
|
|
||||||
|
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
# add to duration bucket
|
||||||
|
k = key(id, entry)
|
||||||
|
if type not in _durations_map:
|
||||||
|
_durations_map[type] = {}
|
||||||
|
_durations_map[type][k] = duration
|
||||||
|
|
||||||
return [ key(id, entry) for id, entry in metadata.items() if not validate or _validate(id, entry) ]
|
if not validate:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
||||||
|
|
||||||
|
return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
|
||||||
|
|
||||||
|
|
||||||
def _get_hdf5_path(path):
|
def _get_hdf5_path(path):
|
||||||
|
@ -348,17 +344,23 @@ def _get_hdf5_path(path):
|
||||||
|
|
||||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||||
data_dir = str(data_dir)
|
data_dir = str(data_dir)
|
||||||
|
|
||||||
|
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||||
|
|
||||||
def _validate( id, entry ):
|
def _validate( id, entry ):
|
||||||
phones = entry.attrs['phonemes']
|
phones = entry.attrs['phonemes']
|
||||||
duration = entry.attrs['duration']
|
duration = entry.attrs['duration']
|
||||||
if type not in _total_durations:
|
|
||||||
_total_durations[type] = 0
|
|
||||||
_total_durations[type] += entry.attrs['duration']
|
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration #and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
|
||||||
|
|
||||||
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
if type not in _durations_map:
|
||||||
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if not validate or _validate(id, entry) ] if key in cfg.hdf5 else []
|
_durations_map[type] = {}
|
||||||
|
_durations_map[type][f"{key}/{id}"] = duration
|
||||||
|
|
||||||
|
if not validate:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
||||||
|
|
||||||
|
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
|
||||||
|
|
||||||
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
|
@ -430,6 +432,7 @@ class Dataset(_Dataset):
|
||||||
self.dataset_type = "training" if self.training else "validation"
|
self.dataset_type = "training" if self.training else "validation"
|
||||||
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
self.dataset = cfg.dataset.training if self.training else cfg.dataset.validation
|
||||||
self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
|
self.sampler_type = cfg.dataset.sample_type # if self.dataset_type == "training" else "group"
|
||||||
|
self.sampler_order = cfg.dataset.sample_order
|
||||||
|
|
||||||
# to-do: do not do validation if there's nothing in the validation
|
# to-do: do not do validation if there's nothing in the validation
|
||||||
# this just makes it be happy
|
# this just makes it be happy
|
||||||
|
@ -446,10 +449,11 @@ class Dataset(_Dataset):
|
||||||
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
|
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
|
||||||
del self.paths_by_spkr_name[key]
|
del self.paths_by_spkr_name[key]
|
||||||
|
|
||||||
|
# flatten paths
|
||||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||||
|
|
||||||
# split dataset accordingly per GPU
|
# split dataset accordingly per GPU
|
||||||
if self.sampler_type == "path" and cfg.distributed and self.training:
|
if cfg.distributed and self.training:
|
||||||
batches = len(self.paths) // world_size()
|
batches = len(self.paths) // world_size()
|
||||||
start = batches * global_rank()
|
start = batches * global_rank()
|
||||||
end = batches * (global_rank() + 1)
|
end = batches * (global_rank() + 1)
|
||||||
|
@ -464,6 +468,40 @@ class Dataset(_Dataset):
|
||||||
self.paths_by_spkr_name[name] = []
|
self.paths_by_spkr_name[name] = []
|
||||||
self.paths_by_spkr_name[name].append( path )
|
self.paths_by_spkr_name[name].append( path )
|
||||||
|
|
||||||
|
# do it here due to the above
|
||||||
|
self.duration = 0
|
||||||
|
self.duration_map = _get_duration_map( self.dataset_type )
|
||||||
|
self.duration_buckets = {}
|
||||||
|
|
||||||
|
# store in corresponding bucket
|
||||||
|
for path in self.paths:
|
||||||
|
duration = self.duration_map[path]
|
||||||
|
self.duration += duration
|
||||||
|
|
||||||
|
# only calc duration if we're not going to order by duration
|
||||||
|
if self.sampler_order != "duration":
|
||||||
|
continue
|
||||||
|
|
||||||
|
bucket = str(int(round(duration)))
|
||||||
|
if bucket not in self.duration_buckets:
|
||||||
|
self.duration_buckets[bucket] = []
|
||||||
|
self.duration_buckets[bucket].append( ( Path(path), duration ) )
|
||||||
|
|
||||||
|
# sort by duration
|
||||||
|
if self.sampler_order == "duration":
|
||||||
|
# sort and interleave
|
||||||
|
for bucket in self.duration_buckets:
|
||||||
|
# sort by duration
|
||||||
|
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
|
||||||
|
# replace with path
|
||||||
|
self.duration_buckets[bucket] = [ x[0] for x in self.duration_buckets[bucket] ]
|
||||||
|
# flatten by paths
|
||||||
|
self.duration_buckets[bucket] = [*_interleaved_reorder(self.duration_buckets[bucket], self.get_speaker)]
|
||||||
|
# flatten paths
|
||||||
|
self.paths = list(itertools.chain.from_iterable(self.duration_buckets.values()))
|
||||||
|
elif self.sampler_order == "shuffle":
|
||||||
|
# just interleave
|
||||||
|
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||||
|
|
||||||
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
||||||
|
|
||||||
|
@ -485,9 +523,6 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||||
|
|
||||||
if self.sampler_type == "path":
|
|
||||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
|
||||||
|
|
||||||
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
||||||
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
|
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
|
||||||
|
|
||||||
|
@ -504,15 +539,6 @@ class Dataset(_Dataset):
|
||||||
if len(self.paths) == 0:
|
if len(self.paths) == 0:
|
||||||
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||||
|
|
||||||
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
|
|
||||||
self.duration = _calculate_durations(self.dataset_type)
|
|
||||||
|
|
||||||
"""
|
|
||||||
@cached_property
|
|
||||||
def phones(self):
|
|
||||||
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_speaker(self, path):
|
def get_speaker(self, path):
|
||||||
if isinstance(path, str):
|
if isinstance(path, str):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user