updated dataloader to hopefully reduce RAM usage
This commit is contained in:
parent
9cfbf94b1c
commit
2053580838
|
@ -881,25 +881,26 @@ class Config(BaseConfig):
|
|||
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
|
||||
|
||||
def set_audio_backend(self, audio_backend):
|
||||
cfg.audio_backend = audio_backend
|
||||
self.audio_backend = audio_backend
|
||||
|
||||
audio_extension = None
|
||||
if audio_backend in ["encodec", "vocos"]:
|
||||
audio_extension = ".enc"
|
||||
cfg.sample_rate = 24_000
|
||||
#cfg.model.resp_levels = 8
|
||||
self.sample_rate = 24_000
|
||||
#self.model.resp_levels = 8
|
||||
elif audio_backend == "dac":
|
||||
audio_extension = ".dac"
|
||||
cfg.sample_rate = 44_100
|
||||
#cfg.model.resp_levels = 9
|
||||
elif cfg.audio_backend == "audiodec":
|
||||
self.sample_rate = 44_100
|
||||
#self.model.resp_levels = 9
|
||||
elif self.audio_backend == "audiodec":
|
||||
audio_extension = ".dec"
|
||||
cfg.sample_rate = 48_000
|
||||
#cfg.model.resp_levels = 8 # ?
|
||||
elif cfg.audio_backend == "nemo":
|
||||
self.sample_rate = 48_000
|
||||
#self.model.resp_levels = 8 # ?
|
||||
elif self.audio_backend == "nemo":
|
||||
audio_extension = ".nem"
|
||||
cfg.sample_rate = 44_100
|
||||
#cfg.model.resp_levels = 8
|
||||
#cfg.model.audio_tokens = 1000
|
||||
self.sample_rate = 44_100
|
||||
#self.model.resp_levels = 8
|
||||
#self.model.audio_tokens = 1000
|
||||
else:
|
||||
raise Exception(f"Unknown audio backend: {audio_backend}")
|
||||
|
||||
|
|
456
vall_e/data.py
456
vall_e/data.py
|
@ -702,136 +702,75 @@ def _get_metadata_extension():
|
|||
def _get_artifact_path(path):
|
||||
return _replace_file_extension(path, _get_artifact_extension())
|
||||
|
||||
_durations_map = {}
|
||||
_similar_map = {}
|
||||
def _get_path_key( type, dir, id ):
|
||||
return f"/{type}/{_get_hdf5_path(dir)}/{id}" if cfg.dataset.use_hdf5 else str(dir / id)
|
||||
|
||||
def _get_duration_map( type="training" ):
|
||||
return _durations_map[type] if type in _durations_map else {}
|
||||
|
||||
def _get_similar_map( type="training" ):
|
||||
return _similar_map[type] if type in _similar_map else {}
|
||||
|
||||
def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
|
||||
def _load_dataset_metadata(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
|
||||
assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
|
||||
|
||||
# for now only ensure metadata-based path
|
||||
assert cfg.dataset.use_metadata, "Metadata required."
|
||||
|
||||
if not dataset_hash_key:
|
||||
dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
|
||||
|
||||
cached_dir = cfg.cache_dir / dataset_hash_key
|
||||
|
||||
cached_durations_path = cached_dir / f"durations[{type}].json"
|
||||
cached_paths_path = cached_dir / f"dataloader[{type}].json"
|
||||
cached_similar_path = cached_dir / f"similar[{type}].json"
|
||||
|
||||
# load the duration table first, since this is independent from the loaded paths
|
||||
if cached_durations_path.exists():
|
||||
_durations_map[type] = json_read( cached_durations_path )
|
||||
# load the similar paths table as well, since this is also independent
|
||||
if cached_similar_path.exists():
|
||||
_similar_map[type] = json_read( cached_similar_path )
|
||||
|
||||
# load the cached valid paths (if we're requesting cache use)
|
||||
if cached_paths_path.exists() and cfg.dataset.cache:
|
||||
# to-do: automatic conversion between HDF5 formatted paths and on-disk paths
|
||||
return json_read( cached_paths_path )
|
||||
cached_path = cached_dir / f"dataset[{type}].json"
|
||||
|
||||
# deduce valid paths
|
||||
paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
|
||||
if cached_path.exists() and cfg.dataset.cache:
|
||||
return json_read( cached_path )
|
||||
|
||||
dataset_metadata = {}
|
||||
def validate_utterance( id, entry ):
|
||||
duration = entry.get('duration', 0)
|
||||
in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
||||
|
||||
if cfg.dataset.validate and type == "training" and not in_bounds:
|
||||
return False
|
||||
|
||||
if cfg.dataset.strict_validate:
|
||||
if cfg.dataset.use_hdf5 and key(type, dir, id) not in cfg.hdf5:
|
||||
return False
|
||||
|
||||
if not (cfg.data_dir / dir / id).with_suffix(_get_artifact_extension()).exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def process_utterance( id, entry, metadata_keys=None ):
|
||||
duration = entry.get('duration', 0)
|
||||
similar = entry.get('similar', None)
|
||||
# store duration length, and similar key name (because indices might shift)
|
||||
return [duration, ([ metadata_keys[idx] for idx in similar ] if similar and metadata_keys else [])]
|
||||
|
||||
for dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent ):
|
||||
metadata_path = cfg.metadata_dir / f'{dir}.json'
|
||||
if not metadata_path.exists():
|
||||
continue
|
||||
|
||||
# to-do: make json_read handle when it actually can't read the file......
|
||||
try:
|
||||
metadata = json_read( metadata_path )
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
speaker = str(dir)
|
||||
metadata_keys = list(metadata.keys())
|
||||
dataset_metadata[speaker] = { id: process_utterance( id, entry, metadata_keys ) for id, entry in metadata.items() if validate_utterance( id, entry ) }
|
||||
|
||||
# remap strings to indices
|
||||
remapped_indices = { k: i for i, k in enumerate(dataset_metadata[speaker].keys()) }
|
||||
for id, (duration, similars) in dataset_metadata[speaker].items():
|
||||
dataset_metadata[speaker][id][1] = [ remapped_indices[k] for k in similars if k in remapped_indices ]
|
||||
|
||||
# and write if global leader (to avoid other processes writing to the same file at once)
|
||||
if is_global_leader():
|
||||
if not cached_dir.exists():
|
||||
cached_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
json_write( _similar_map[type], cached_similar_path, truncate=True )
|
||||
json_write( _durations_map[type], cached_durations_path, truncate=True )
|
||||
json_write( paths, cached_paths_path, truncate=True )
|
||||
json_write( dataset_metadata, cached_path, truncate=True )
|
||||
|
||||
return paths
|
||||
|
||||
def _load_paths_from_metadata(group_name, type="training", validate=False):
|
||||
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
|
||||
|
||||
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
|
||||
|
||||
def key( id, entry=None ):
|
||||
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else str(data_dir / id)
|
||||
|
||||
metadata_path = cfg.metadata_dir / f'{group_name}.json'
|
||||
metadata = {}
|
||||
|
||||
if cfg.dataset.use_metadata and metadata_path.exists():
|
||||
try:
|
||||
metadata = json_read( metadata_path )
|
||||
except Exception as e:
|
||||
return {}
|
||||
|
||||
if len(metadata) == 0:
|
||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
|
||||
|
||||
# this might be slow
|
||||
def _exists( id, entry ):
|
||||
if not cfg.dataset.strict_validate:
|
||||
return True
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
return key(id, entry) in cfg.hdf5
|
||||
|
||||
return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
|
||||
|
||||
metadata_keys = list(metadata.keys())
|
||||
def _validate( id, entry ):
|
||||
phones = entry.get('phones', 0)
|
||||
duration = entry.get('duration', 0)
|
||||
similar = entry.get('similar', None)
|
||||
|
||||
k = key(id, entry)
|
||||
|
||||
# add to duration bucket
|
||||
if type not in _durations_map:
|
||||
_durations_map[type] = {}
|
||||
_durations_map[type][k] = duration
|
||||
|
||||
# add to similar bucket
|
||||
if type not in _similar_map:
|
||||
_similar_map[type] = {}
|
||||
_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
|
||||
|
||||
if not validate:
|
||||
return True
|
||||
|
||||
in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
||||
if in_bounds and not _exists( id, entry ):
|
||||
return False
|
||||
|
||||
return in_bounds
|
||||
|
||||
return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
|
||||
|
||||
|
||||
def _get_hdf5_path(path):
|
||||
# to-do: better validation
|
||||
return str(path)
|
||||
|
||||
def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
||||
data_dir = str(data_dir)
|
||||
|
||||
key = f"/{type}/{_get_hdf5_path(data_dir)}"
|
||||
|
||||
def _validate( id, entry ):
|
||||
phones = entry.attrs['phonemes']
|
||||
duration = entry.attrs['duration']
|
||||
|
||||
if type not in _durations_map:
|
||||
_durations_map[type] = {}
|
||||
_durations_map[type][f"{key}/{id}"] = float(duration)
|
||||
|
||||
if not validate:
|
||||
return True
|
||||
|
||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
|
||||
|
||||
return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
|
||||
return dataset_metadata
|
||||
|
||||
def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
|
||||
if isinstance(path, str):
|
||||
|
@ -878,7 +817,7 @@ class Dataset(_Dataset):
|
|||
self,
|
||||
phone_symmap=None,
|
||||
training=False,
|
||||
extra_paths_by_spkr_name: dict[str, list] = {},
|
||||
extra_paths_by_speaker_name: dict[str, list] = {},
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -910,40 +849,31 @@ class Dataset(_Dataset):
|
|||
if self.sampler_order == "duration" and self.sampler_type != "path":
|
||||
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
|
||||
|
||||
# dict of paths keyed by speaker names
|
||||
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
|
||||
self.duration_map = _get_duration_map( self.dataset_type )
|
||||
|
||||
# cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
|
||||
if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
|
||||
keys = list(self.paths_by_spkr_name.keys())
|
||||
for key in keys:
|
||||
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
|
||||
del self.paths_by_spkr_name[key]
|
||||
continue
|
||||
|
||||
# slice away extraneous utterances
|
||||
if cfg.dataset.max_utterances:
|
||||
self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
|
||||
|
||||
# flatten paths
|
||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||
# dict that maps [speaker][id] to (duration, similar utterances)
|
||||
self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
|
||||
|
||||
# cull speakers with too little utterances
|
||||
for speaker in self.metadata.keys():
|
||||
utterances = len(self.metadata[speaker])
|
||||
if utterances < cfg.dataset.min_utterances:
|
||||
del self.metadata[speaker]
|
||||
|
||||
self.paths = []
|
||||
self.speakers = list(self.metadata.keys())
|
||||
for speaker_id, speaker in enumerate(self.speakers):
|
||||
utterances = len(self.metadata[speaker])
|
||||
self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
|
||||
|
||||
# split dataset accordingly per GPU
|
||||
if cfg.distributed and self.training:
|
||||
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
|
||||
|
||||
# recreate paths_by_spkr_name
|
||||
self.paths_by_spkr_name = {}
|
||||
for path in self.paths:
|
||||
name = cfg.get_spkr( Path(path) )
|
||||
if name not in self.paths_by_spkr_name:
|
||||
self.paths_by_spkr_name[name] = []
|
||||
self.paths_by_spkr_name[name].append( path )
|
||||
|
||||
# store in corresponding bucket
|
||||
for path in self.paths:
|
||||
duration = self.duration_map[path]
|
||||
for (speaker_id, utterance_id) in self.paths:
|
||||
speaker = self.speakers[speaker_id]
|
||||
utterance = list(self.metadata[speaker].keys())[utterance_id]
|
||||
|
||||
duration, _ = self.metadata[speaker][utterance]
|
||||
self.duration += duration
|
||||
|
||||
# only calc duration if we're going to order by duration
|
||||
|
@ -953,7 +883,7 @@ class Dataset(_Dataset):
|
|||
bucket = int(round(duration))
|
||||
if bucket not in self.duration_buckets:
|
||||
self.duration_buckets[bucket] = []
|
||||
self.duration_buckets[bucket].append( ( Path(path), duration ) )
|
||||
self.duration_buckets[bucket].append( ( (speaker_id, utterance_id), duration ) )
|
||||
|
||||
# sort by duration
|
||||
if self.sampler_order == "duration":
|
||||
|
@ -964,43 +894,28 @@ class Dataset(_Dataset):
|
|||
# sort and interleave
|
||||
for bucket in self.duration_buckets:
|
||||
# sort by duration
|
||||
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
|
||||
self.duration_buckets[bucket].sort( key=lambda x: x[-1] )
|
||||
# split to retain tuples
|
||||
flattened[bucket] = self.duration_buckets[bucket]
|
||||
# replace with path
|
||||
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
|
||||
# flatten by paths
|
||||
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
|
||||
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
|
||||
# flatten paths
|
||||
self.paths = list(itertools.chain.from_iterable(flattened.values()))
|
||||
elif self.sampler_order == "random":
|
||||
random.shuffle( self.paths )
|
||||
else:
|
||||
# just interleave
|
||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||
|
||||
# dict of speakers keyed by speaker group
|
||||
self.spkrs_by_spkr_group = {}
|
||||
for data_dir in self.dataset:
|
||||
spkr = cfg.get_spkr( data_dir / "dummy" )
|
||||
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
|
||||
|
||||
if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
|
||||
continue
|
||||
|
||||
if spkr_group not in self.spkrs_by_spkr_group:
|
||||
self.spkrs_by_spkr_group[spkr_group] = []
|
||||
|
||||
self.spkrs_by_spkr_group[spkr_group].append( spkr )
|
||||
|
||||
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
|
||||
self.paths = [*_interleaved_reorder(self.paths, lambda x: x[0])]
|
||||
|
||||
"""
|
||||
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
|
||||
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
|
||||
"""
|
||||
|
||||
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||
self.spkr_symmap = self._get_spkr_symmap()
|
||||
self.spkr_group_symmap = self._get_spkr_group_symmap()
|
||||
self.speaker_symmap = self._get_speaker_symmap()
|
||||
self.lang_symmap = self._get_lang_symmap()
|
||||
self.tone_symmap = self._get_tone_symmap()
|
||||
self.task_symmap = self._get_task_symmap()
|
||||
|
@ -1022,27 +937,16 @@ class Dataset(_Dataset):
|
|||
if len(self.paths) == 0:
|
||||
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||
|
||||
if self.sampler_type == "path" and self.training:
|
||||
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
||||
self.sampler = BatchedOrderedSampler(
|
||||
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
|
||||
max_duration=cfg.dataset.sample_max_duration_batch,
|
||||
max_batch_size=self.batch_size,
|
||||
shuffle=self.sampler_shuffle,
|
||||
)
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
|
||||
self.samplers = {}
|
||||
self.spkr_samplers = {}
|
||||
if self.training and self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
|
||||
self.sampler = BatchedOrderedSampler(
|
||||
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
|
||||
max_duration=cfg.dataset.sample_max_duration_batch,
|
||||
max_batch_size=self.batch_size,
|
||||
shuffle=self.sampler_shuffle,
|
||||
)
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.sampler = RandomSampler( len(self) )
|
||||
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
|
||||
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||
|
||||
# dereference buckets
|
||||
self.duration_map = None
|
||||
self.duration_buckets = None
|
||||
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
|
||||
|
||||
self.load_state_dict()
|
||||
|
||||
|
@ -1053,13 +957,13 @@ class Dataset(_Dataset):
|
|||
def get_speaker(self, path):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
res = cfg.get_spkr(path)
|
||||
res = cfg.get_speaker(path)
|
||||
return res
|
||||
|
||||
def get_speaker_group(self, path):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
res = cfg.get_spkr_group(path)
|
||||
res = cfg.get_speaker_group(path)
|
||||
return res
|
||||
|
||||
# this isn't really necessary since our data/metadata contains markers for languages, but this is still in in-case it's needed to force a language setting (for example, whisperX's lang isn't that accurate at times)
|
||||
|
@ -1071,10 +975,6 @@ class Dataset(_Dataset):
|
|||
|
||||
return lang.lower()
|
||||
|
||||
@cached_property
|
||||
def spkrs(self):
|
||||
return sorted({self.get_speaker(path) for path in self.paths})
|
||||
|
||||
@cached_property
|
||||
def tasks(self):
|
||||
if not self.training:
|
||||
|
@ -1088,13 +988,16 @@ class Dataset(_Dataset):
|
|||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
state_dict = self.sampler.get_state()
|
||||
"""
|
||||
if self.sampler_type == "path":
|
||||
state_dict = self.sampler.get_state()
|
||||
else:
|
||||
state_dict = {
|
||||
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
|
||||
"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
|
||||
"speaker_samplers": { name: sampler.get_state() for name, sampler in self.speaker_samplers.items() },
|
||||
}
|
||||
"""
|
||||
|
||||
if "dataset_hash_key" not in state_dict:
|
||||
state_dict["dataset_hash_key"] = self.dataset_hash_key
|
||||
|
@ -1117,6 +1020,8 @@ class Dataset(_Dataset):
|
|||
_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
|
||||
return
|
||||
|
||||
state_dict = self.sampler.set_state(state_dict)
|
||||
"""
|
||||
if self.sampler_type == "path":
|
||||
state_dict = self.sampler.set_state(state_dict)
|
||||
else:
|
||||
|
@ -1125,19 +1030,17 @@ class Dataset(_Dataset):
|
|||
continue
|
||||
self.samplers[name].set_state( sampler )
|
||||
|
||||
for name, sampler in state_dict["spkr_samplers"].items():
|
||||
if name not in self.spkr_samplers:
|
||||
for name, sampler in state_dict["speaker_samplers"].items():
|
||||
if name not in self.speaker_samplers:
|
||||
continue
|
||||
self.spkr_samplers[name].set_state( sampler )
|
||||
self.speaker_samplers[name].set_state( sampler )
|
||||
"""
|
||||
|
||||
def _get_phone_symmap(self):
|
||||
return get_phone_symmap()
|
||||
|
||||
def _get_spkr_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.spkrs)}
|
||||
|
||||
def _get_spkr_group_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.spkr_groups)}
|
||||
def _get_speaker_symmap(self):
|
||||
return {s: i for i, s in enumerate(self.speakers)}
|
||||
|
||||
def _get_lang_symmap(self):
|
||||
return get_lang_symmap()
|
||||
|
@ -1159,17 +1062,19 @@ class Dataset(_Dataset):
|
|||
return qnt
|
||||
|
||||
def sample_speakers(self, ignore=[]):
|
||||
choices = set(self.spkrs) - set(ignore)
|
||||
choices = set(self.speakers) - set(ignore)
|
||||
return random.choice([*choices])
|
||||
|
||||
def sample_utterance(self, spkr_name, ignore=[]):
|
||||
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
|
||||
def sample_utterance(self, speaker_name, ignore=[]):
|
||||
choices = [*(set(self.metadata[speaker_name].keys()) - set(ignore))]
|
||||
|
||||
if len(choices) == 0:
|
||||
return None, None, None
|
||||
|
||||
path = random.choice(choices)
|
||||
utterance_id = random.choice(choices)
|
||||
utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
|
||||
|
||||
path = cfg.data_dir / speaker_name / utterance_name
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
||||
|
@ -1201,13 +1106,11 @@ class Dataset(_Dataset):
|
|||
return path, text, resps
|
||||
|
||||
# icky slop
|
||||
def get_similar_utterance(self, path, offset=None ):
|
||||
def get_similar_utterance(self, speaker_name, utterance_name, offset=None ):
|
||||
if offset is None:
|
||||
offset = cfg.dataset.prompt_similar_top_k_offset
|
||||
|
||||
root = Path( *path.parts[:-1] )
|
||||
reference = path.name
|
||||
similars = _similar_map[self.dataset_type].get(str(path), None)
|
||||
_, similars = self.metadata[speaker_name][utterance_name]
|
||||
|
||||
if not similars:
|
||||
return None
|
||||
|
@ -1223,41 +1126,28 @@ class Dataset(_Dataset):
|
|||
if offset_end >= len( similars ):
|
||||
return None
|
||||
|
||||
utterance_keys = list(self.metadata[speaker_name].keys())
|
||||
if cfg.dataset.prompt_similar_top_k > 1:
|
||||
name = random.choice( similars[offset:offset_end] )
|
||||
index = random.choice( similars[offset:offset_end] )
|
||||
else:
|
||||
name = similars[offset]
|
||||
index = similars[offset]
|
||||
|
||||
path = root / name
|
||||
return utterance_keys[index]
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
if key not in cfg.hdf5[key]:
|
||||
return None
|
||||
elif not path.exists():
|
||||
return None
|
||||
|
||||
return path
|
||||
|
||||
def sample_prompts(self, spkr_name, reference, should_trim=True):
|
||||
def sample_prompts(self, speaker_name, utterance_name=None, should_trim=True):
|
||||
# return no prompt if explicitly requested for who knows why
|
||||
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
|
||||
if len(self.paths_by_spkr_name[spkr_name]) <= 1:
|
||||
if len(self.metadata[speaker_name]) <= 1:
|
||||
return None
|
||||
|
||||
prom_list = []
|
||||
|
||||
choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
|
||||
choices = set(self.metadata[speaker_name].keys()) - {utterance_name}
|
||||
choices = [*choices]
|
||||
|
||||
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
|
||||
if len(choices) == 0:
|
||||
choices = [*set(self.paths_by_spkr_name[spkr_name])]
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Failed to find another different utterance for {spkr_name}."
|
||||
)
|
||||
"""
|
||||
choices = [*set(self.metadata[speaker_name].keys())]
|
||||
|
||||
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
|
||||
should_trim = False
|
||||
|
@ -1271,17 +1161,18 @@ class Dataset(_Dataset):
|
|||
|
||||
for _ in range(cfg.dataset.prompt_max_samples):
|
||||
# yuck
|
||||
path = None
|
||||
if reference is not None:
|
||||
sampled_utterance = None
|
||||
if utterance_name is not None:
|
||||
if random.random() < cfg.dataset.prompt_similar_p:
|
||||
try:
|
||||
path = self.get_similar_utterance( reference, offset = len(prom_list) )
|
||||
sampled_utterance = self.get_similar_utterance( speaker_name, utterance_name, offset = len(prom_list) )
|
||||
except Exception as e:
|
||||
path = None
|
||||
sampled_utterance = None
|
||||
|
||||
if not path:
|
||||
path = random.choice(choices)
|
||||
if not sampled_utterance:
|
||||
sampled_utterance = random.choice(choices)
|
||||
|
||||
path = cfg.data_dir / speaker_name / sampled_utterance
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
if key not in cfg.hdf5:
|
||||
|
@ -1294,6 +1185,7 @@ class Dataset(_Dataset):
|
|||
except Exception as e:
|
||||
_logger.warning(f'Failed to load artifact: {path} ({e})')
|
||||
path = None
|
||||
continue
|
||||
|
||||
if 0 < trim_length and trim_length < qnt.shape[0]:
|
||||
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
|
||||
|
@ -1321,27 +1213,10 @@ class Dataset(_Dataset):
|
|||
|
||||
bos_id, space_id, eos_id = self.empty_text
|
||||
|
||||
if self.sampler_type == "group":
|
||||
spkr_group = self.spkr_groups[index]
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
spkr_name = self.spkr_samplers[spkr_group].sample()
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
path = self.samplers[spkr_name].sample()
|
||||
elif self.sampler_type == "speaker":
|
||||
spkr_name = self.spkrs[index]
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
path = self.samplers[spkr_name].sample()
|
||||
spkr_group = self.get_speaker_group(path)
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
else:
|
||||
path = self.paths[index]
|
||||
spkr_name = self.get_speaker(path)
|
||||
spkr_id = self.spkr_symmap[spkr_name]
|
||||
spkr_group = self.get_speaker_group(path)
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
|
||||
if not isinstance( path, Path ):
|
||||
path = Path( path )
|
||||
speaker_id, utterance_id = self.paths[index]
|
||||
speaker_name = self.speakers[speaker_id]
|
||||
utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
|
||||
path = cfg.data_dir / speaker_name / utterance_name
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
@ -1378,7 +1253,7 @@ class Dataset(_Dataset):
|
|||
tone = metadata["tone"] if "tone" in metadata else None
|
||||
text_string = metadata["text"] if "text" in metadata else None
|
||||
|
||||
lang = self.get_language(spkr_group) if not lang else lang.lower()
|
||||
lang = lang.lower() if lang else "en"
|
||||
|
||||
raw_text = torch.tensor(text_tokenize(text_string)).to(torch.int16) if text_string else None
|
||||
|
||||
|
@ -1398,7 +1273,7 @@ class Dataset(_Dataset):
|
|||
if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
|
||||
ignore_paths = []
|
||||
for _ in range( 1, cfg.dataset.resps_max_samples ):
|
||||
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
|
||||
path, txt, qnt = self.sample_utterance(speaker_name, ignore=ignore_paths)
|
||||
ignore_paths.append(path)
|
||||
|
||||
# <s>[original text]</s><s>[new text]</s>
|
||||
|
@ -1412,7 +1287,7 @@ class Dataset(_Dataset):
|
|||
# might be better to decode => concat waveforms with silence in between => reencode
|
||||
# as you technically can't just append encodec sequences together like this without issues
|
||||
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
|
||||
|
||||
|
||||
task = random.choice(self.tasks)
|
||||
|
||||
if task not in self.task_symmap:
|
||||
|
@ -1420,8 +1295,10 @@ class Dataset(_Dataset):
|
|||
|
||||
# Base TTS (<text><prompt> => <resp>)
|
||||
if task == "tts":
|
||||
proms = self.sample_prompts(spkr_name, reference=path)
|
||||
proms = self.sample_prompts(speaker_name, utterance_name)
|
||||
|
||||
"""
|
||||
proms = self.sample_prompts(speaker_name, reference=path)
|
||||
if random.random() < cfg.dataset.prompt_inject_noise_p:
|
||||
# sample random noise
|
||||
noise = self.sample_noise()
|
||||
|
@ -1429,7 +1306,7 @@ class Dataset(_Dataset):
|
|||
noise = repeat_extend_audio(noise, proms.shape[0])
|
||||
# create the input prompt by merging the target audio with the noise
|
||||
proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
|
||||
|
||||
"""
|
||||
|
||||
# VALL-E Continuous (<text><partial resp> => <remaining resp> )
|
||||
# (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
|
||||
|
@ -1442,7 +1319,7 @@ class Dataset(_Dataset):
|
|||
proms = resps[:trim_length, :]
|
||||
resps = resps[trim_length:, :]
|
||||
else:
|
||||
path, txt, qnt = self.sample_utterance(spkr_name)
|
||||
path, txt, qnt = self.sample_utterance(speaker_name)
|
||||
|
||||
# <s>[original text]</s><s>[new text]</s>
|
||||
if naive:
|
||||
|
@ -1471,7 +1348,7 @@ class Dataset(_Dataset):
|
|||
|
||||
# Duration prediction (<text><prompt> => len(<resp>))
|
||||
elif task == "len":
|
||||
proms = self.sample_prompts(spkr_name, reference=path)
|
||||
proms = self.sample_prompts(speaker_name, utterance_name)
|
||||
|
||||
elif task in ["phn", "un-phn"]:
|
||||
proms = []
|
||||
|
@ -1503,10 +1380,10 @@ class Dataset(_Dataset):
|
|||
# target speech extraction ( <text><prom><resp + other resp> => <resp> )
|
||||
elif task == "tse":
|
||||
# sample a prompt
|
||||
proms = self.sample_prompts(spkr_name, reference=path)
|
||||
proms = self.sample_prompts(speaker_name, utterance_name)
|
||||
|
||||
# sample another speaker
|
||||
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
|
||||
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[speaker_name]))
|
||||
|
||||
# overlay the random speaker over the target audio
|
||||
other_resps = merge_audio( resps, other_resps, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
|
||||
|
@ -1531,7 +1408,7 @@ class Dataset(_Dataset):
|
|||
|
||||
samples = []
|
||||
for _ in range( 4 ):
|
||||
sampled = self.sample_utterance(spkr_name, ignore=[s[0] for s in samples])
|
||||
sampled = self.sample_utterance(speaker_name, ignore=[s[0] for s in samples])
|
||||
samples.append( sampled )
|
||||
|
||||
pre_text, mid_text, post_text, edit_text = [ s[1][1:-1] for s in samples ]
|
||||
|
@ -1611,8 +1488,8 @@ class Dataset(_Dataset):
|
|||
return dict(
|
||||
index=index,
|
||||
path=Path(path),
|
||||
spkr_name=spkr_name,
|
||||
spkr_id=spkr_id,
|
||||
speaker_name=speaker_name,
|
||||
speaker_id=speaker_id,
|
||||
task=task,
|
||||
lang=lang,
|
||||
tone=tone,
|
||||
|
@ -1640,10 +1517,8 @@ class Dataset(_Dataset):
|
|||
return len(self.sampler if self.sampler is not None else self) // self.batch_size
|
||||
|
||||
def __len__(self):
|
||||
if self.sampler_type == "group":
|
||||
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
||||
if self.sampler_type == "speaker":
|
||||
return min(len(self.spkrs), self._head or len(self.spkrs))
|
||||
return min(len(self.speakers), self._head or len(self.speakers))
|
||||
return min(len(self.paths), self._head or len(self.paths))
|
||||
|
||||
|
||||
|
@ -1690,8 +1565,7 @@ def create_train_dataloader():
|
|||
train_dl = _create_dataloader(train_dataset, training=True)
|
||||
|
||||
_logger.info(str(train_dataset.phone_symmap))
|
||||
_logger.info(str(train_dataset.spkr_symmap))
|
||||
_logger.info(str(train_dataset.spkr_group_symmap))
|
||||
_logger.info(str(train_dataset.speaker_symmap))
|
||||
|
||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
|
||||
|
@ -1706,8 +1580,7 @@ def create_val_dataloader():
|
|||
val_dl = _create_dataloader(val_dataset, training=False)
|
||||
|
||||
_logger.info(str(val_dataset.phone_symmap))
|
||||
_logger.info(str(val_dataset.spkr_symmap))
|
||||
_logger.info(str(val_dataset.spkr_group_symmap))
|
||||
_logger.info(str(val_dataset.speaker_symmap))
|
||||
|
||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
|
||||
|
@ -1724,8 +1597,7 @@ def create_train_val_dataloader():
|
|||
val_dl = _create_dataloader(val_dataset, training=False)
|
||||
|
||||
_logger.info(str(train_dataset.phone_symmap))
|
||||
_logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
|
||||
_logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')
|
||||
_logger.info(f'#speakers (train): {len(train_dataset.speaker_symmap)}')
|
||||
|
||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||
|
@ -2018,7 +1890,6 @@ if __name__ == "__main__":
|
|||
|
||||
samples = {
|
||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||
#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
|
||||
|
@ -2044,31 +1915,18 @@ if __name__ == "__main__":
|
|||
for i in range(len(v)):
|
||||
_logger.info(f'{k}[{i}]: {v[i]}')
|
||||
elif args.action == "validate":
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
train_dl, val_dl = create_train_val_dataloader()
|
||||
dataset = train_dl.dataset
|
||||
|
||||
missing = []
|
||||
symmap = get_phone_symmap()
|
||||
|
||||
for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
|
||||
if dataset.sampler_type == "group":
|
||||
spkr_group = dataset.spkr_groups[index]
|
||||
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
|
||||
spkr_name = dataset.spkr_samplers[spkr_group].sample()
|
||||
spkr_id = dataset.spkr_symmap[spkr_name]
|
||||
path = dataset.samplers[spkr_name].sample()
|
||||
elif dataset.sampler_type == "speaker":
|
||||
spkr_name = dataset.spkrs[index]
|
||||
spkr_id = dataset.spkr_symmap[spkr_name]
|
||||
path = dataset.samplers[spkr_name].sample()
|
||||
spkr_group = dataset.get_speaker_group(path)
|
||||
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
|
||||
else:
|
||||
path = dataset.paths[index]
|
||||
spkr_name = dataset.get_speaker(path)
|
||||
spkr_id = dataset.spkr_symmap[spkr_name]
|
||||
spkr_group = dataset.get_speaker_group(path)
|
||||
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
|
||||
speaker_id, utterance_id = dataset.paths[index]
|
||||
speaker_name = dataset.speakers[speaker_id]
|
||||
speaker_keys = list(dataset.metadata[speaker_name].keys())
|
||||
utterance_name = speaker_keys[utterance_id]
|
||||
path = cfg.data_dir / speaker_name / utterance_name
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
@ -2113,7 +1971,7 @@ if __name__ == "__main__":
|
|||
index = 0
|
||||
cfg.dataset.tasks_list = args.tasks.split(",")
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
train_dl, val_dl = create_train_val_dataloader()
|
||||
batch = next(iter(train_dl))
|
||||
|
||||
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):
|
||||
|
|
Loading…
Reference in New Issue
Block a user