updated dataloader to hopefully reduce RAM usage
parent 9cfbf94b1c · commit 2053580838
@@ -881,25 +881,26 @@ class Config(BaseConfig):
     supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])

     def set_audio_backend(self, audio_backend):
-        cfg.audio_backend = audio_backend
+        self.audio_backend = audio_backend

         audio_extension = None
         if audio_backend in ["encodec", "vocos"]:
             audio_extension = ".enc"
-            cfg.sample_rate = 24_000
-            #cfg.model.resp_levels = 8
+            self.sample_rate = 24_000
+            #self.model.resp_levels = 8
         elif audio_backend == "dac":
             audio_extension = ".dac"
-            cfg.sample_rate = 44_100
-            #cfg.model.resp_levels = 9
-        elif cfg.audio_backend == "audiodec":
+            self.sample_rate = 44_100
+            #self.model.resp_levels = 9
+        elif self.audio_backend == "audiodec":
             audio_extension = ".dec"
-            cfg.sample_rate = 48_000
-            #cfg.model.resp_levels = 8 # ?
-        elif cfg.audio_backend == "nemo":
+            self.sample_rate = 48_000
+            #self.model.resp_levels = 8 # ?
+        elif self.audio_backend == "nemo":
             audio_extension = ".nem"
-            cfg.sample_rate = 44_100
-            #cfg.model.resp_levels = 8
-            #cfg.model.audio_tokens = 1000
+            self.sample_rate = 44_100
+            #self.model.resp_levels = 8
+            #self.model.audio_tokens = 1000
         else:
             raise Exception(f"Unknown audio backend: {audio_backend}")
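Note on this hunk: `set_audio_backend` previously wrote through the module-level `cfg` singleton even when called on a different `Config` instance; binding to `self` keeps the mutation on the instance being configured. A minimal sketch of the failure mode, using a hypothetical class rather than the repo's `Config`:

    class Config:
        sample_rate = 0

        def set_audio_backend_buggy(self, audio_backend):
            cfg.sample_rate = 24_000   # mutates the module-level singleton, not `self`

        def set_audio_backend_fixed(self, audio_backend):
            self.sample_rate = 24_000  # mutates the instance it was called on

    cfg = Config()    # the module-level singleton
    other = Config()  # a second, independent instance

    other.set_audio_backend_buggy("encodec")
    assert cfg.sample_rate == 24_000 and other.sample_rate == 0  # wrong object changed

    other.set_audio_backend_fixed("encodec")
    assert other.sample_rate == 24_000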
vall_e/data.py (428 changed lines)
@@ -702,136 +702,75 @@ def _get_metadata_extension():
 def _get_artifact_path(path):
     return _replace_file_extension(path, _get_artifact_extension())

-_durations_map = {}
-_similar_map = {}
+def _get_path_key( type, dir, id ):
+    return f"/{type}/{_get_hdf5_path(dir)}/{id}" if cfg.dataset.use_hdf5 else str(dir / id)

-def _get_duration_map( type="training" ):
-    return _durations_map[type] if type in _durations_map else {}
-
-def _get_similar_map( type="training" ):
-    return _similar_map[type] if type in _similar_map else {}
-
-def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
+def _load_dataset_metadata(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
     assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
+    # for now only ensure metadata-based path
+    assert cfg.dataset.use_metadata, "Metadata required."

     if not dataset_hash_key:
         dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))

     cached_dir = cfg.cache_dir / dataset_hash_key
+    cached_path = cached_dir / f"dataset[{type}].json"

-    cached_durations_path = cached_dir / f"durations[{type}].json"
-    cached_paths_path = cached_dir / f"dataloader[{type}].json"
-    cached_similar_path = cached_dir / f"similar[{type}].json"
+    if cached_path.exists() and cfg.dataset.cache:
+        return json_read( cached_path )

-    # load the duration table first, since this is independent from the loaded paths
-    if cached_durations_path.exists():
-        _durations_map[type] = json_read( cached_durations_path )
-    # load the similar paths table as well, since this is also independent
-    if cached_similar_path.exists():
-        _similar_map[type] = json_read( cached_similar_path )
-
-    # load the cached valid paths (if we're requesting cache use)
-    if cached_paths_path.exists() and cfg.dataset.cache:
-        # to-do: automatic conversion between HDF5 formatted paths and on-disk paths
-        return json_read( cached_paths_path )
-
-    # deduce valid paths
-    paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+    dataset_metadata = {}
+    def validate_utterance( id, entry ):
+        duration = entry.get('duration', 0)
+        in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
+
+        if cfg.dataset.validate and type == "training" and not in_bounds:
+            return False
+
+        if cfg.dataset.strict_validate:
+            if cfg.dataset.use_hdf5 and key(type, dir, id) not in cfg.hdf5:
+                return False
+
+            if not (cfg.data_dir / dir / id).with_suffix(_get_artifact_extension()).exists():
+                return False
+
+        return True
+
+    def process_utterance( id, entry, metadata_keys=None ):
+        duration = entry.get('duration', 0)
+        similar = entry.get('similar', None)
+        # store duration length, and similar key name (because indices might shift)
+        return [duration, ([ metadata_keys[idx] for idx in similar ] if similar and metadata_keys else [])]
+
+    for dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent ):
+        metadata_path = cfg.metadata_dir / f'{dir}.json'
+        if not metadata_path.exists():
+            continue
+
+        # to-do: make json_read handle when it actually can't read the file......
+        try:
+            metadata = json_read( metadata_path )
+        except Exception as e:
+            continue
+
+        speaker = str(dir)
+        metadata_keys = list(metadata.keys())
+        dataset_metadata[speaker] = { id: process_utterance( id, entry, metadata_keys ) for id, entry in metadata.items() if validate_utterance( id, entry ) }
+
+        # remap strings to indices
+        remapped_indices = { k: i for i, k in enumerate(dataset_metadata[speaker].keys()) }
+        for id, (duration, similars) in dataset_metadata[speaker].items():
+            dataset_metadata[speaker][id][1] = [ remapped_indices[k] for k in similars if k in remapped_indices ]

     # and write if global leader (to avoid other processes writing to the same file at once)
     if is_global_leader():
         if not cached_dir.exists():
             cached_dir.mkdir(parents=True, exist_ok=True)

-        json_write( _similar_map[type], cached_similar_path, truncate=True )
-        json_write( _durations_map[type], cached_durations_path, truncate=True )
-        json_write( paths, cached_paths_path, truncate=True )
+        json_write( dataset_metadata, cached_path, truncate=True )

-    return paths
+    return dataset_metadata

-def _load_paths_from_metadata(group_name, type="training", validate=False):
-    data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
-
-    _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
-
-    def key( id, entry=None ):
-        return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else str(data_dir / id)
-
-    metadata_path = cfg.metadata_dir / f'{group_name}.json'
-    metadata = {}
-
-    if cfg.dataset.use_metadata and metadata_path.exists():
-        try:
-            metadata = json_read( metadata_path )
-        except Exception as e:
-            return {}
-
-    if len(metadata) == 0:
-        return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
-
-    # this might be slow
-    def _exists( id, entry ):
-        if not cfg.dataset.strict_validate:
-            return True
-
-        if cfg.dataset.use_hdf5:
-            return key(id, entry) in cfg.hdf5
-
-        return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
-
-    metadata_keys = list(metadata.keys())
-    def _validate( id, entry ):
-        phones = entry.get('phones', 0)
-        duration = entry.get('duration', 0)
-        similar = entry.get('similar', None)
-
-        k = key(id, entry)
-
-        # add to duration bucket
-        if type not in _durations_map:
-            _durations_map[type] = {}
-        _durations_map[type][k] = duration
-
-        # add to similar bucket
-        if type not in _similar_map:
-            _similar_map[type] = {}
-        _similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
-
-        if not validate:
-            return True
-
-        in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
-        if in_bounds and not _exists( id, entry ):
-            return False
-
-        return in_bounds
-
-    return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
-
-def _get_hdf5_path(path):
-    # to-do: better validation
-    return str(path)
-
-def _get_hdf5_paths( data_dir, type="training", validate=False ):
-    data_dir = str(data_dir)
-
-    key = f"/{type}/{_get_hdf5_path(data_dir)}"
-
-    def _validate( id, entry ):
-        phones = entry.attrs['phonemes']
-        duration = entry.attrs['duration']
-
-        if type not in _durations_map:
-            _durations_map[type] = {}
-        _durations_map[type][f"{key}/{id}"] = float(duration)
-
-        if not validate:
-            return True
-
-        return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
-
-    return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
-
 def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
     if isinstance(path, str):
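For reference, the rework above collapses three parallel structures (per-speaker path lists plus the global `_durations_map` and `_similar_map`) into one nested dict. A standalone sketch of the shape `_load_dataset_metadata` builds — the field layout mirrors the diff, the data is illustrative:

    # { speaker: { utterance_key: [duration_seconds, [similar utterances]] } }
    metadata = {
        "speaker_a": {
            "utt_000": [2.4, ["utt_002", "utt_001"]],  # similar utterances, by key name
            "utt_001": [3.1, ["utt_000"]],
            "utt_002": [1.7, []],
        },
    }

    # key names are then remapped to integer indices into the surviving keys,
    # so the stored lists stay small and stay valid after filtering:
    speaker = metadata["speaker_a"]
    remapped = {k: i for i, k in enumerate(speaker.keys())}
    for utt in speaker:
        speaker[utt][1] = [remapped[k] for k in speaker[utt][1] if k in remapped]

    assert speaker["utt_000"][1] == [2, 1]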
@@ -878,7 +817,7 @@ class Dataset(_Dataset):
         self,
         phone_symmap=None,
         training=False,
-        extra_paths_by_spkr_name: dict[str, list] = {},
+        extra_paths_by_speaker_name: dict[str, list] = {},
     ):
         super().__init__()
@@ -910,40 +849,31 @@ class Dataset(_Dataset):
         if self.sampler_order == "duration" and self.sampler_type != "path":
             raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')

-        # dict of paths keyed by speaker names
-        self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
-        self.duration_map = _get_duration_map( self.dataset_type )
+        # dict that maps [speaker][id] to (duration, similar utterances)
+        self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)

-        # cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
-        if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
-            keys = list(self.paths_by_spkr_name.keys())
-            for key in keys:
-                if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
-                    del self.paths_by_spkr_name[key]
-                    continue
+        # cull speakers with too little utterances
+        for speaker in self.metadata.keys():
+            utterances = len(self.metadata[speaker])
+            if utterances < cfg.dataset.min_utterances:
+                del self.metadata[speaker]

-                # slice away extraneous utterances
-                if cfg.dataset.max_utterances:
-                    self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
+        self.paths = []
+        self.speakers = list(self.metadata.keys())
+        for speaker_id, speaker in enumerate(self.speakers):
+            utterances = len(self.metadata[speaker])
+            self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]

-        # flatten paths
-        self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
-
         # split dataset accordingly per GPU
         if cfg.distributed and self.training:
             self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]

-            # recreate paths_by_spkr_name
-            self.paths_by_spkr_name = {}
-            for path in self.paths:
-                name = cfg.get_spkr( Path(path) )
-                if name not in self.paths_by_spkr_name:
-                    self.paths_by_spkr_name[name] = []
-                self.paths_by_spkr_name[name].append( path )
-
         # store in corresponding bucket
-        for path in self.paths:
-            duration = self.duration_map[path]
+        for (speaker_id, utterance_id) in self.paths:
+            speaker = self.speakers[speaker_id]
+            utterance = list(self.metadata[speaker].keys())[utterance_id]
+
+            duration, _ = self.metadata[speaker][utterance]
             self.duration += duration

         # only calc duration if we're going to order by duration
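This hunk is where the RAM saving in the commit title comes from: `self.paths` now holds `(speaker_id, utterance_id)` integer pairs instead of one full path string per utterance, so the speaker string is stored once in `self.speakers` and everything else is recovered on demand. A rough sketch of the difference, with illustrative values:

    import sys

    path = "training/some_dataset/speaker_00042/utterance_000123.enc"
    pair = (42, 123)

    # one path string per utterance vs. a pair of small ints per utterance
    print(sys.getsizeof(path))  # roughly 100 bytes for the string object alone
    print(sys.getsizeof(pair))  # roughly 56 bytes for the tuple

    # recovery at access time, mirroring the loop above:
    speakers = ["speaker_00042"]
    metadata = {"speaker_00042": {"utterance_000123": [2.4, []]}}
    speaker_id, utterance_id = (0, 0)
    speaker = speakers[speaker_id]
    utterance = list(metadata[speaker].keys())[utterance_id]
    assert utterance == "utterance_000123"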
@@ -953,7 +883,7 @@ class Dataset(_Dataset):
            bucket = int(round(duration))
            if bucket not in self.duration_buckets:
                self.duration_buckets[bucket] = []
-            self.duration_buckets[bucket].append( ( Path(path), duration ) )
+            self.duration_buckets[bucket].append( ( (speaker_id, utterance_id), duration ) )

         # sort by duration
         if self.sampler_order == "duration":
@@ -964,43 +894,28 @@ class Dataset(_Dataset):
             # sort and interleave
             for bucket in self.duration_buckets:
                 # sort by duration
-                self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+                self.duration_buckets[bucket].sort( key=lambda x: x[-1] )
                 # split to retain tuples
                 flattened[bucket] = self.duration_buckets[bucket]
                 # replace with path
                 flattened[bucket] = [ x[0] for x in flattened[bucket] ]
                 # flatten by paths
-                flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
+                flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
             # flatten paths
             self.paths = list(itertools.chain.from_iterable(flattened.values()))
         elif self.sampler_order == "random":
             random.shuffle( self.paths )
         else:
             # just interleave
-            self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
+            self.paths = [*_interleaved_reorder(self.paths, lambda x: x[0])]

-        # dict of speakers keyed by speaker group
-        self.spkrs_by_spkr_group = {}
-        for data_dir in self.dataset:
-            spkr = cfg.get_spkr( data_dir / "dummy" )
-            spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
-
-            if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
-                continue
-
-            if spkr_group not in self.spkrs_by_spkr_group:
-                self.spkrs_by_spkr_group[spkr_group] = []
-
-            self.spkrs_by_spkr_group[spkr_group].append( spkr )
-
-        self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
-
+        """
         self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
         self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
+        """

         self.phone_symmap = phone_symmap or self._get_phone_symmap()
-        self.spkr_symmap = self._get_spkr_symmap()
-        self.spkr_group_symmap = self._get_spkr_group_symmap()
+        self.speaker_symmap = self._get_speaker_symmap()
         self.lang_symmap = self._get_lang_symmap()
         self.tone_symmap = self._get_tone_symmap()
         self.task_symmap = self._get_task_symmap()
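The sort and interleave keys above change from `self.get_speaker` (a path-to-name lookup) to `lambda x: x[0]` (the speaker id already in the tuple), which avoids parsing path strings entirely. A sketch of what interleaving by that key does — an illustrative reimplementation, not necessarily the repo's `_interleaved_reorder`:

    from itertools import zip_longest

    def interleaved_reorder(items, key):
        # group items by key, then emit one item per group round-robin,
        # so consecutive samples tend to come from different speakers
        groups = {}
        for item in items:
            groups.setdefault(key(item), []).append(item)
        out = []
        for row in zip_longest(*groups.values()):
            out.extend(x for x in row if x is not None)
        return out

    paths = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0)]
    print(interleaved_reorder(paths, lambda x: x[0]))
    # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (0, 2)]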
@@ -1022,8 +937,7 @@ class Dataset(_Dataset):
         if len(self.paths) == 0:
             raise ValueError(f"No valid path is found for {self.dataset_type}")

-        if self.sampler_type == "path" and self.training:
-            if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+        if self.training and self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
             self.sampler = BatchedOrderedSampler(
                 self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
                 max_duration=cfg.dataset.sample_max_duration_batch,
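Context for the simplified condition above: with `sampler_order == "duration"`, the duration buckets feed a `BatchedOrderedSampler` that packs variable-sized batches up to a total duration budget rather than a fixed batch size. A greedy sketch of that idea — illustrative only, not the repo's sampler:

    def pack_batches(buckets, max_duration):
        # buckets: {rounded_seconds: [(index, duration), ...]}, already sorted
        batches, current, total = [], [], 0.0
        for bucket in sorted(buckets):
            for index, duration in buckets[bucket]:
                if current and total + duration > max_duration:
                    batches.append(current)
                    current, total = [], 0.0
                current.append(index)
                total += duration
        if current:
            batches.append(current)
        return batches

    buckets = {2: [((0, 0), 2.1), ((1, 0), 2.4)], 5: [((0, 1), 5.0)]}
    print(pack_batches(buckets, max_duration=6.0))
    # [[(0, 0), (1, 0)], [(0, 1)]]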
@@ -1033,16 +947,6 @@ class Dataset(_Dataset):
             self.batch_size = 1
         else:
             self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
-            self.samplers = {}
-            self.spkr_samplers = {}
-        else:
-            self.sampler = RandomSampler( len(self) )
-            self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
-            self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
-
-        # dereference buckets
-        self.duration_map = None
-        self.duration_buckets = None

         self.load_state_dict()
@@ -1053,13 +957,13 @@ class Dataset(_Dataset):
     def get_speaker(self, path):
         if isinstance(path, str):
             path = Path(path)
-        res = cfg.get_spkr(path)
+        res = cfg.get_speaker(path)
         return res

     def get_speaker_group(self, path):
         if isinstance(path, str):
             path = Path(path)
-        res = cfg.get_spkr_group(path)
+        res = cfg.get_speaker_group(path)
         return res

     # this isn't really necessary since our data/metadata contains markers for languages, but this is still in in-case it's needed to force a language setting (for example, whisperX's lang isn't that accurate at times)
@@ -1071,10 +975,6 @@ class Dataset(_Dataset):

         return lang.lower()

-    @cached_property
-    def spkrs(self):
-        return sorted({self.get_speaker(path) for path in self.paths})
-
     @cached_property
     def tasks(self):
         if not self.training:
@@ -1088,13 +988,16 @@ class Dataset(_Dataset):
         if not path.parent.exists():
             path.parent.mkdir(parents=True, exist_ok=True)

+        state_dict = self.sampler.get_state()
+        """
         if self.sampler_type == "path":
             state_dict = self.sampler.get_state()
         else:
             state_dict = {
                 "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
-                "spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
+                "speaker_samplers": { name: sampler.get_state() for name, sampler in self.speaker_samplers.items() },
             }
+        """

         if "dataset_hash_key" not in state_dict:
             state_dict["dataset_hash_key"] = self.dataset_hash_key
@@ -1117,6 +1020,8 @@ class Dataset(_Dataset):
             _logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
             return

+        state_dict = self.sampler.set_state(state_dict)
+        """
         if self.sampler_type == "path":
             state_dict = self.sampler.set_state(state_dict)
         else:
@@ -1125,19 +1030,17 @@ class Dataset(_Dataset):
                     continue
                 self.samplers[name].set_state( sampler )

-            for name, sampler in state_dict["spkr_samplers"].items():
-                if name not in self.spkr_samplers:
+            for name, sampler in state_dict["speaker_samplers"].items():
+                if name not in self.speaker_samplers:
                     continue
-                self.spkr_samplers[name].set_state( sampler )
+                self.speaker_samplers[name].set_state( sampler )
+        """

     def _get_phone_symmap(self):
         return get_phone_symmap()

-    def _get_spkr_symmap(self):
-        return {s: i for i, s in enumerate(self.spkrs)}
-
-    def _get_spkr_group_symmap(self):
-        return {s: i for i, s in enumerate(self.spkr_groups)}
+    def _get_speaker_symmap(self):
+        return {s: i for i, s in enumerate(self.speakers)}

     def _get_lang_symmap(self):
         return get_lang_symmap()
@@ -1159,17 +1062,19 @@ class Dataset(_Dataset):
         return qnt

     def sample_speakers(self, ignore=[]):
-        choices = set(self.spkrs) - set(ignore)
+        choices = set(self.speakers) - set(ignore)
         return random.choice([*choices])

-    def sample_utterance(self, spkr_name, ignore=[]):
-        choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
+    def sample_utterance(self, speaker_name, ignore=[]):
+        choices = [*(set(self.metadata[speaker_name].keys()) - set(ignore))]

         if len(choices) == 0:
             return None, None, None

-        path = random.choice(choices)
+        utterance_id = random.choice(choices)
+        utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]

+        path = cfg.data_dir / speaker_name / utterance_name
         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
@@ -1201,13 +1106,11 @@ class Dataset(_Dataset):
         return path, text, resps

     # icky slop
-    def get_similar_utterance(self, path, offset=None ):
+    def get_similar_utterance(self, speaker_name, utterance_name, offset=None ):
         if offset is None:
             offset = cfg.dataset.prompt_similar_top_k_offset

-        root = Path( *path.parts[:-1] )
-        reference = path.name
-        similars = _similar_map[self.dataset_type].get(str(path), None)
+        _, similars = self.metadata[speaker_name][utterance_name]

         if not similars:
             return None
@@ -1223,41 +1126,28 @@ class Dataset(_Dataset):
         if offset_end >= len( similars ):
             return None

+        utterance_keys = list(self.metadata[speaker_name].keys())
         if cfg.dataset.prompt_similar_top_k > 1:
-            name = random.choice( similars[offset:offset_end] )
+            index = random.choice( similars[offset:offset_end] )
         else:
-            name = similars[offset]
+            index = similars[offset]

-        path = root / name
+        return utterance_keys[index]

-        if cfg.dataset.use_hdf5:
-            key = _get_hdf5_path(path)
-            if key not in cfg.hdf5[key]:
-                return None
-        elif not path.exists():
-            return None
-
-        return path
-
-    def sample_prompts(self, spkr_name, reference, should_trim=True):
+    def sample_prompts(self, speaker_name, utterance_name=None, should_trim=True):
         # return no prompt if explicitly requested for who knows why
         # or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
-        if len(self.paths_by_spkr_name[spkr_name]) <= 1:
+        if len(self.metadata[speaker_name]) <= 1:
             return None

         prom_list = []

-        choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
+        choices = set(self.metadata[speaker_name].keys()) - {utterance_name}
         choices = [*choices]

         # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
         if len(choices) == 0:
-            choices = [*set(self.paths_by_spkr_name[spkr_name])]
-            """
-            raise ValueError(
-                f"Failed to find another different utterance for {spkr_name}."
-            )
-            """
+            choices = [*set(self.metadata[speaker_name].keys())]

         if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
             should_trim = False
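With this rework, `get_similar_utterance` returns an utterance key rather than a path: the stored `similar` list holds integer indices into the speaker's insertion-ordered utterance keys. A sketch of the lookup, with illustrative data:

    metadata = {
        "speaker_a": {
            "utt_000": [2.4, [2, 1]],  # most-similar first, as indices
            "utt_001": [3.1, [0]],
            "utt_002": [1.7, [0]],
        },
    }

    def get_similar_utterance(speaker_name, utterance_name, offset=0):
        _, similars = metadata[speaker_name][utterance_name]
        if not similars or offset >= len(similars):
            return None
        utterance_keys = list(metadata[speaker_name].keys())
        return utterance_keys[similars[offset]]

    assert get_similar_utterance("speaker_a", "utt_000") == "utt_002"
    assert get_similar_utterance("speaker_a", "utt_000", offset=1) == "utt_001"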
@@ -1271,17 +1161,18 @@ class Dataset(_Dataset):

         for _ in range(cfg.dataset.prompt_max_samples):
             # yuck
-            path = None
-            if reference is not None:
+            sampled_utterance = None
+            if utterance_name is not None:
                 if random.random() < cfg.dataset.prompt_similar_p:
                     try:
-                        path = self.get_similar_utterance( reference, offset = len(prom_list) )
+                        sampled_utterance = self.get_similar_utterance( speaker_name, utterance_name, offset = len(prom_list) )
                     except Exception as e:
-                        path = None
+                        sampled_utterance = None

-            if not path:
-                path = random.choice(choices)
+            if not sampled_utterance:
+                sampled_utterance = random.choice(choices)

+            path = cfg.data_dir / speaker_name / sampled_utterance
             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
                 if key not in cfg.hdf5:
@@ -1294,6 +1185,7 @@ class Dataset(_Dataset):
                except Exception as e:
                    _logger.warning(f'Failed to load artifact: {path} ({e})')
                    path = None
+                    continue

             if 0 < trim_length and trim_length < qnt.shape[0]:
                 qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
@@ -1321,27 +1213,10 @@ class Dataset(_Dataset):

         bos_id, space_id, eos_id = self.empty_text

-        if self.sampler_type == "group":
-            spkr_group = self.spkr_groups[index]
-            #spkr_group_id = self.spkr_group_symmap[spkr_group]
-            spkr_name = self.spkr_samplers[spkr_group].sample()
-            spkr_id = self.spkr_symmap[spkr_name]
-            path = self.samplers[spkr_name].sample()
-        elif self.sampler_type == "speaker":
-            spkr_name = self.spkrs[index]
-            spkr_id = self.spkr_symmap[spkr_name]
-            path = self.samplers[spkr_name].sample()
-            spkr_group = self.get_speaker_group(path)
-            #spkr_group_id = self.spkr_group_symmap[spkr_group]
-        else:
-            path = self.paths[index]
-            spkr_name = self.get_speaker(path)
-            spkr_id = self.spkr_symmap[spkr_name]
-            spkr_group = self.get_speaker_group(path)
-            #spkr_group_id = self.spkr_group_symmap[spkr_group]
-
-        if not isinstance( path, Path ):
-            path = Path( path )
+        speaker_id, utterance_id = self.paths[index]
+        speaker_name = self.speakers[speaker_id]
+        utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
+        path = cfg.data_dir / speaker_name / utterance_name

         if cfg.dataset.use_hdf5:
             key = _get_hdf5_path(path)
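After this change, `__getitem__` always resolves a sample the same way: the stored `(speaker_id, utterance_id)` pair is expanded to names, and the on-disk path is only materialized as `cfg.data_dir / speaker_name / utterance_name` (or an HDF5 key) at access time. A condensed sketch of that resolution, with hypothetical values:

    from pathlib import Path

    data_dir = Path("./training/data")
    speakers = ["speaker_a"]
    metadata = {"speaker_a": {"utt_000": [2.4, []], "utt_001": [3.1, []]}}
    paths = [(0, 0), (0, 1)]

    index = 1
    speaker_id, utterance_id = paths[index]
    speaker_name = speakers[speaker_id]
    utterance_name = list(metadata[speaker_name].keys())[utterance_id]
    path = data_dir / speaker_name / utterance_name
    print(path)  # training/data/speaker_a/utt_001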
@@ -1378,7 +1253,7 @@ class Dataset(_Dataset):
         tone = metadata["tone"] if "tone" in metadata else None
         text_string = metadata["text"] if "text" in metadata else None

-        lang = self.get_language(spkr_group) if not lang else lang.lower()
+        lang = lang.lower() if lang else "en"

         raw_text = torch.tensor(text_tokenize(text_string)).to(torch.int16) if text_string else None
@@ -1398,7 +1273,7 @@ class Dataset(_Dataset):
         if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
             ignore_paths = []
             for _ in range( 1, cfg.dataset.resps_max_samples ):
-                path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
+                path, txt, qnt = self.sample_utterance(speaker_name, ignore=ignore_paths)
                 ignore_paths.append(path)

                 # <s>[original text]</s><s>[new text]</s>
@@ -1420,8 +1295,10 @@ class Dataset(_Dataset):

         # Base TTS (<text><prompt> => <resp>)
         if task == "tts":
-            proms = self.sample_prompts(spkr_name, reference=path)
+            proms = self.sample_prompts(speaker_name, utterance_name)

+            """
+            proms = self.sample_prompts(speaker_name, reference=path)
             if random.random() < cfg.dataset.prompt_inject_noise_p:
                 # sample random noise
                 noise = self.sample_noise()
@@ -1429,7 +1306,7 @@ class Dataset(_Dataset):
                 noise = repeat_extend_audio(noise, proms.shape[0])
                 # create the input prompt by merging the target audio with the noise
                 proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
+            """

         # VALL-E Continuous (<text><partial resp> => <remaining resp> )
         # (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
@@ -1442,7 +1319,7 @@ class Dataset(_Dataset):
             proms = resps[:trim_length, :]
             resps = resps[trim_length:, :]
         else:
-            path, txt, qnt = self.sample_utterance(spkr_name)
+            path, txt, qnt = self.sample_utterance(speaker_name)

             # <s>[original text]</s><s>[new text]</s>
             if naive:
@@ -1471,7 +1348,7 @@ class Dataset(_Dataset):

         # Duration prediction (<text><prompt> => len(<resp>))
         elif task == "len":
-            proms = self.sample_prompts(spkr_name, reference=path)
+            proms = self.sample_prompts(speaker_name, utterance_name)

         elif task in ["phn", "un-phn"]:
             proms = []
@@ -1503,10 +1380,10 @@ class Dataset(_Dataset):
         # target speech extraction ( <text><prom><resp + other resp> => <resp> )
         elif task == "tse":
             # sample a prompt
-            proms = self.sample_prompts(spkr_name, reference=path)
+            proms = self.sample_prompts(speaker_name, utterance_name)

             # sample another speaker
-            _, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
+            _, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[speaker_name]))

             # overlay the random speaker over the target audio
             other_resps = merge_audio( resps, other_resps, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
@@ -1531,7 +1408,7 @@ class Dataset(_Dataset):

             samples = []
             for _ in range( 4 ):
-                sampled = self.sample_utterance(spkr_name, ignore=[s[0] for s in samples])
+                sampled = self.sample_utterance(speaker_name, ignore=[s[0] for s in samples])
                 samples.append( sampled )

             pre_text, mid_text, post_text, edit_text = [ s[1][1:-1] for s in samples ]
@@ -1611,8 +1488,8 @@ class Dataset(_Dataset):
         return dict(
             index=index,
             path=Path(path),
-            spkr_name=spkr_name,
-            spkr_id=spkr_id,
+            speaker_name=speaker_name,
+            speaker_id=speaker_id,
             task=task,
             lang=lang,
             tone=tone,
@@ -1640,10 +1517,8 @@ class Dataset(_Dataset):
         return len(self.sampler if self.sampler is not None else self) // self.batch_size

     def __len__(self):
-        if self.sampler_type == "group":
-            return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
         if self.sampler_type == "speaker":
-            return min(len(self.spkrs), self._head or len(self.spkrs))
+            return min(len(self.speakers), self._head or len(self.speakers))
         return min(len(self.paths), self._head or len(self.paths))

@@ -1690,8 +1565,7 @@ def create_train_dataloader():
     train_dl = _create_dataloader(train_dataset, training=True)

     _logger.info(str(train_dataset.phone_symmap))
-    _logger.info(str(train_dataset.spkr_symmap))
-    _logger.info(str(train_dataset.spkr_group_symmap))
+    _logger.info(str(train_dataset.speaker_symmap))

     _logger.info(f"#samples (train): {len(train_dataset)}.")
     _logger.info(f"#duration (train): {str(train_dataset.duration)}.")
@@ -1706,8 +1580,7 @@ def create_val_dataloader():
     val_dl = _create_dataloader(val_dataset, training=False)

     _logger.info(str(val_dataset.phone_symmap))
-    _logger.info(str(val_dataset.spkr_symmap))
-    _logger.info(str(val_dataset.spkr_group_symmap))
+    _logger.info(str(val_dataset.speaker_symmap))

     _logger.info(f"#samples (val): {len(val_dataset)}.")
     _logger.info(f"#duration (val): {str(val_dataset.duration)}.")
@@ -1724,8 +1597,7 @@ def create_train_val_dataloader():
     val_dl = _create_dataloader(val_dataset, training=False)

     _logger.info(str(train_dataset.phone_symmap))
-    _logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
-    _logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')
+    _logger.info(f'#speakers (train): {len(train_dataset.speaker_symmap)}')

     _logger.info(f"#samples (train): {len(train_dataset)}.")
     _logger.info(f"#samples (val): {len(val_dataset)}.")
@@ -2018,7 +1890,6 @@ if __name__ == "__main__":

     samples = {
         "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
-        #"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
         "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
     }
@@ -2044,31 +1915,18 @@
             for i in range(len(v)):
                 _logger.info(f'{k}[{i}]: {v[i]}')
     elif args.action == "validate":
-        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+        train_dl, val_dl = create_train_val_dataloader()
         dataset = train_dl.dataset

         missing = []
         symmap = get_phone_symmap()

         for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
-            if dataset.sampler_type == "group":
-                spkr_group = dataset.spkr_groups[index]
-                #spkr_group_id = dataset.spkr_group_symmap[spkr_group]
-                spkr_name = dataset.spkr_samplers[spkr_group].sample()
-                spkr_id = dataset.spkr_symmap[spkr_name]
-                path = dataset.samplers[spkr_name].sample()
-            elif dataset.sampler_type == "speaker":
-                spkr_name = dataset.spkrs[index]
-                spkr_id = dataset.spkr_symmap[spkr_name]
-                path = dataset.samplers[spkr_name].sample()
-                spkr_group = dataset.get_speaker_group(path)
-                #spkr_group_id = dataset.spkr_group_symmap[spkr_group]
-            else:
-                path = dataset.paths[index]
-                spkr_name = dataset.get_speaker(path)
-                spkr_id = dataset.spkr_symmap[spkr_name]
-                spkr_group = dataset.get_speaker_group(path)
-                #spkr_group_id = dataset.spkr_group_symmap[spkr_group]
+            speaker_id, utterance_id = dataset.paths[index]
+            speaker_name = dataset.speakers[speaker_id]
+            speaker_keys = list(dataset.metadata[speaker_name].keys())
+            utterance_name = speaker_keys[utterance_id]
+            path = cfg.data_dir / speaker_name / utterance_name

             if cfg.dataset.use_hdf5:
                 key = _get_hdf5_path(path)
@@ -2113,7 +1971,7 @@ if __name__ == "__main__":
         index = 0
         cfg.dataset.tasks_list = args.tasks.split(",")

-        train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
+        train_dl, val_dl = create_train_val_dataloader()
         batch = next(iter(train_dl))

         for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):