updated dataloader to hopefully reduce RAM usage

This commit is contained in:
mrq 2025-03-15 13:14:37 -05:00
parent 9cfbf94b1c
commit 2053580838
2 changed files with 170 additions and 311 deletions
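In broad strokes, the commit replaces the flattened list of per-utterance path strings (plus the separate duration and similarity maps) with one metadata dict keyed by speaker, and indexes utterances as compact (speaker_id, utterance_id) tuples. A minimal, self-contained sketch of the before/after shape — names mirror the diff below, sizes are illustrative only:

import sys

# Before: every utterance kept as a full path string, flattened per speaker.
paths_by_speaker = {"speaker_a": ["speaker_a/utt_0001", "speaker_a/utt_0002"]}
paths_old = [p for ps in paths_by_speaker.values() for p in ps]

# After: per-speaker metadata of {utterance_id: [duration, similar_indices]},
# and the flat list stores compact (speaker_id, utterance_id) index tuples.
metadata = {"speaker_a": {"utt_0001": [3.2, [1]], "utt_0002": [2.7, [0]]}}
speakers = list(metadata.keys())
paths_new = [
    (speaker_id, utterance_id)
    for speaker_id, speaker in enumerate(speakers)
    for utterance_id in range(len(metadata[speaker]))
]

print(sys.getsizeof(paths_old[0]), sys.getsizeof(paths_new[0]))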


@@ -881,25 +881,26 @@ class Config(BaseConfig):
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
def set_audio_backend(self, audio_backend):
cfg.audio_backend = audio_backend
self.audio_backend = audio_backend
audio_extension = None
if audio_backend in ["encodec", "vocos"]:
audio_extension = ".enc"
cfg.sample_rate = 24_000
#cfg.model.resp_levels = 8
self.sample_rate = 24_000
#self.model.resp_levels = 8
elif audio_backend == "dac":
audio_extension = ".dac"
cfg.sample_rate = 44_100
#cfg.model.resp_levels = 9
elif cfg.audio_backend == "audiodec":
self.sample_rate = 44_100
#self.model.resp_levels = 9
elif self.audio_backend == "audiodec":
audio_extension = ".dec"
cfg.sample_rate = 48_000
#cfg.model.resp_levels = 8 # ?
elif cfg.audio_backend == "nemo":
self.sample_rate = 48_000
#self.model.resp_levels = 8 # ?
elif self.audio_backend == "nemo":
audio_extension = ".nem"
cfg.sample_rate = 44_100
#cfg.model.resp_levels = 8
#cfg.model.audio_tokens = 1000
self.sample_rate = 44_100
#self.model.resp_levels = 8
#self.model.audio_tokens = 1000
else:
raise Exception(f"Unknown audio backend: {audio_backend}")


@@ -702,136 +702,75 @@ def _get_metadata_extension():
def _get_artifact_path(path):
return _replace_file_extension(path, _get_artifact_extension())
_durations_map = {}
_similar_map = {}
def _get_path_key( type, dir, id ):
return f"/{type}/{_get_hdf5_path(dir)}/{id}" if cfg.dataset.use_hdf5 else str(dir / id)
def _get_duration_map( type="training" ):
return _durations_map[type] if type in _durations_map else {}
def _get_similar_map( type="training" ):
return _similar_map[type] if type in _similar_map else {}
def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
def _load_dataset_metadata(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
# for now only ensure metadata-based path
assert cfg.dataset.use_metadata, "Metadata required."
if not dataset_hash_key:
dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
cached_dir = cfg.cache_dir / dataset_hash_key
cached_durations_path = cached_dir / f"durations[{type}].json"
cached_paths_path = cached_dir / f"dataloader[{type}].json"
cached_similar_path = cached_dir / f"similar[{type}].json"
# load the duration table first, since this is independent from the loaded paths
if cached_durations_path.exists():
_durations_map[type] = json_read( cached_durations_path )
# load the similar paths table as well, since this is also independent
if cached_similar_path.exists():
_similar_map[type] = json_read( cached_similar_path )
# load the cached valid paths (if we're requesting cache use)
if cached_paths_path.exists() and cfg.dataset.cache:
# to-do: automatic conversion between HDF5 formatted paths and on-disk paths
return json_read( cached_paths_path )
cached_path = cached_dir / f"dataset[{type}].json"
# deduce valid paths
paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
if cached_path.exists() and cfg.dataset.cache:
return json_read( cached_path )
dataset_metadata = {}
def validate_utterance( id, entry ):
duration = entry.get('duration', 0)
in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
if cfg.dataset.validate and type == "training" and not in_bounds:
return False
if cfg.dataset.strict_validate:
if cfg.dataset.use_hdf5 and _get_path_key(type, dir, id) not in cfg.hdf5:
return False
if not (cfg.data_dir / dir / id).with_suffix(_get_artifact_extension()).exists():
return False
return True
def process_utterance( id, entry, metadata_keys=None ):
duration = entry.get('duration', 0)
similar = entry.get('similar', None)
# store the duration, and the similar-utterance key names (because raw indices might shift)
return [duration, ([ metadata_keys[idx] for idx in similar ] if similar and metadata_keys else [])]
for dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent ):
metadata_path = cfg.metadata_dir / f'{dir}.json'
if not metadata_path.exists():
continue
# to-do: make json_read handle when it actually can't read the file......
try:
metadata = json_read( metadata_path )
except Exception as e:
continue
speaker = str(dir)
metadata_keys = list(metadata.keys())
dataset_metadata[speaker] = { id: process_utterance( id, entry, metadata_keys ) for id, entry in metadata.items() if validate_utterance( id, entry ) }
# remap strings to indices
remapped_indices = { k: i for i, k in enumerate(dataset_metadata[speaker].keys()) }
for id, (duration, similars) in dataset_metadata[speaker].items():
dataset_metadata[speaker][id][1] = [ remapped_indices[k] for k in similars if k in remapped_indices ]
# and write if global leader (to avoid other processes writing to the same file at once)
if is_global_leader():
if not cached_dir.exists():
cached_dir.mkdir(parents=True, exist_ok=True)
json_write( _similar_map[type], cached_similar_path, truncate=True )
json_write( _durations_map[type], cached_durations_path, truncate=True )
json_write( paths, cached_paths_path, truncate=True )
json_write( dataset_metadata, cached_path, truncate=True )
return paths
def _load_paths_from_metadata(group_name, type="training", validate=False):
data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
def key( id, entry=None ):
return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else str(data_dir / id)
metadata_path = cfg.metadata_dir / f'{group_name}.json'
metadata = {}
if cfg.dataset.use_metadata and metadata_path.exists():
try:
metadata = json_read( metadata_path )
except Exception as e:
return {}
if len(metadata) == 0:
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
# this might be slow
def _exists( id, entry ):
if not cfg.dataset.strict_validate:
return True
if cfg.dataset.use_hdf5:
return key(id, entry) in cfg.hdf5
return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
metadata_keys = list(metadata.keys())
def _validate( id, entry ):
phones = entry.get('phones', 0)
duration = entry.get('duration', 0)
similar = entry.get('similar', None)
k = key(id, entry)
# add to duration bucket
if type not in _durations_map:
_durations_map[type] = {}
_durations_map[type][k] = duration
# add to similar bucket
if type not in _similar_map:
_similar_map[type] = {}
_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
if not validate:
return True
in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
if in_bounds and not _exists( id, entry ):
return False
return in_bounds
return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
def _get_hdf5_path(path):
# to-do: better validation
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
data_dir = str(data_dir)
key = f"/{type}/{_get_hdf5_path(data_dir)}"
def _validate( id, entry ):
phones = entry.attrs['phonemes']
duration = entry.attrs['duration']
if type not in _durations_map:
_durations_map[type] = {}
_durations_map[type][f"{key}/{id}"] = float(duration)
if not validate:
return True
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
return dataset_metadata
def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
if isinstance(path, str):
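The "similar" bookkeeping above happens in two steps because validation culls entries: process_utterance first stores similar utterances by key name (raw indices would shift once invalid entries are dropped), and only after the speaker's dict is built are those names remapped to indices over the surviving keys. A self-contained sketch of that remap:

# surviving utterances for one speaker: {id: [duration, [similar key names]]};
# assume "utt_0002" was culled by validate_utterance
surviving = {"utt_0001": [3.2, ["utt_0003"]], "utt_0003": [2.9, ["utt_0001"]]}
remapped_indices = {k: i for i, k in enumerate(surviving.keys())}
for id, (duration, similars) in surviving.items():
    surviving[id][1] = [remapped_indices[k] for k in similars if k in remapped_indices]
print(surviving)  # {'utt_0001': [3.2, [1]], 'utt_0003': [2.9, [0]]}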
@@ -878,7 +817,7 @@ class Dataset(_Dataset):
self,
phone_symmap=None,
training=False,
extra_paths_by_spkr_name: dict[str, list] = {},
extra_paths_by_speaker_name: dict[str, list] = {},
):
super().__init__()
@@ -910,40 +849,31 @@ class Dataset(_Dataset):
if self.sampler_order == "duration" and self.sampler_type != "path":
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, but this combination will not give the expected results.')
# dict of paths keyed by speaker names
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
self.duration_map = _get_duration_map( self.dataset_type )
# cull speakers if they do not have enough utterances (or cull speakers with too many utterances)
if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
keys = list(self.paths_by_spkr_name.keys())
for key in keys:
if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
del self.paths_by_spkr_name[key]
continue
# slice away extraneous utterances
if cfg.dataset.max_utterances:
self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
# flatten paths
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
# dict that maps [speaker][id] to (duration, similar utterances)
self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
# cull speakers with too few utterances
for speaker in list(self.metadata.keys()):
utterances = len(self.metadata[speaker])
if utterances < cfg.dataset.min_utterances:
del self.metadata[speaker]
self.paths = []
self.speakers = list(self.metadata.keys())
for speaker_id, speaker in enumerate(self.speakers):
utterances = len(self.metadata[speaker])
self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
# split dataset accordingly per GPU
if cfg.distributed and self.training:
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
# recreate paths_by_spkr_name
self.paths_by_spkr_name = {}
for path in self.paths:
name = cfg.get_spkr( Path(path) )
if name not in self.paths_by_spkr_name:
self.paths_by_spkr_name[name] = []
self.paths_by_spkr_name[name].append( path )
# store in corresponding bucket
for path in self.paths:
duration = self.duration_map[path]
for (speaker_id, utterance_id) in self.paths:
speaker = self.speakers[speaker_id]
utterance = list(self.metadata[speaker].keys())[utterance_id]
duration, _ = self.metadata[speaker][utterance]
self.duration += duration
# only calc duration if we're going to order by duration
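Resolving a compact (speaker_id, utterance_id) pair back to names and a path recurs throughout the rest of the diff (__getitem__, sample_utterance, the validate action). A self-contained sketch of the lookup, assuming the on-disk layout data_dir/speaker/utterance the diff implies:

from pathlib import Path

metadata = {"speaker_a": {"utt_0001": [3.2, []], "utt_0002": [2.7, [0]]}}
speakers = list(metadata.keys())
data_dir = Path("data")

def resolve(speaker_id, utterance_id):
    speaker_name = speakers[speaker_id]
    # rebuilding the key list per lookup trades CPU for the RAM that
    # caching per-speaker key lists would cost
    utterance_name = list(metadata[speaker_name].keys())[utterance_id]
    duration, similar = metadata[speaker_name][utterance_name]
    return data_dir / speaker_name / utterance_name, duration, similar

print(resolve(0, 1))  # (PosixPath('data/speaker_a/utt_0002'), 2.7, [0])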
@@ -953,7 +883,7 @@ class Dataset(_Dataset):
bucket = int(round(duration))
if bucket not in self.duration_buckets:
self.duration_buckets[bucket] = []
self.duration_buckets[bucket].append( ( Path(path), duration ) )
self.duration_buckets[bucket].append( ( (speaker_id, utterance_id), duration ) )
# sort by duration
if self.sampler_order == "duration":
@@ -964,43 +894,28 @@ class Dataset(_Dataset):
# sort and interleave
for bucket in self.duration_buckets:
# sort by duration
self.duration_buckets[bucket].sort( key=lambda x: x[1] )
self.duration_buckets[bucket].sort( key=lambda x: x[-1] )
# split to retain tuples
flattened[bucket] = self.duration_buckets[bucket]
# replace with path
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
# flatten by paths
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
# flatten paths
self.paths = list(itertools.chain.from_iterable(flattened.values()))
elif self.sampler_order == "random":
random.shuffle( self.paths )
else:
# just interleave
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
# dict of speakers keyed by speaker group
self.spkrs_by_spkr_group = {}
for data_dir in self.dataset:
spkr = cfg.get_spkr( data_dir / "dummy" )
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
continue
if spkr_group not in self.spkrs_by_spkr_group:
self.spkrs_by_spkr_group[spkr_group] = []
self.spkrs_by_spkr_group[spkr_group].append( spkr )
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
self.paths = [*_interleaved_reorder(self.paths, lambda x: x[0])]
"""
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
"""
self.phone_symmap = phone_symmap or self._get_phone_symmap()
self.spkr_symmap = self._get_spkr_symmap()
self.spkr_group_symmap = self._get_spkr_group_symmap()
self.speaker_symmap = self._get_speaker_symmap()
self.lang_symmap = self._get_lang_symmap()
self.tone_symmap = self._get_tone_symmap()
self.task_symmap = self._get_task_symmap()
@@ -1022,27 +937,16 @@ class Dataset(_Dataset):
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
if self.sampler_type == "path" and self.training:
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler(
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
max_duration=cfg.dataset.sample_max_duration_batch,
max_batch_size=self.batch_size,
shuffle=self.sampler_shuffle,
)
self.batch_size = 1
else:
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
self.samplers = {}
self.spkr_samplers = {}
if self.training and self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler(
self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
max_duration=cfg.dataset.sample_max_duration_batch,
max_batch_size=self.batch_size,
shuffle=self.sampler_shuffle,
)
self.batch_size = 1
else:
self.sampler = RandomSampler( len(self) )
self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
# dereference buckets
self.duration_map = None
self.duration_buckets = None
self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
self.load_state_dict()
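BatchedOrderedSampler consumes the duration buckets built earlier; the packing idea is to fill each batch until adding the next utterance would exceed sample_max_duration_batch seconds (or the batch-size cap), which is why batch_size is forced to 1 afterwards. A rough, self-contained sketch of that packing, not the sampler's actual code:

def pack_batches(items, max_duration=60.0, max_batch_size=16):
    # items: ((speaker_id, utterance_id), duration) pairs, sorted by duration
    batches, batch, total = [], [], 0.0
    for key, duration in items:
        if batch and (total + duration > max_duration or len(batch) >= max_batch_size):
            batches.append(batch)
            batch, total = [], 0.0
        batch.append(key)
        total += duration
    if batch:
        batches.append(batch)
    return batches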
@@ -1053,13 +957,13 @@ class Dataset(_Dataset):
def get_speaker(self, path):
if isinstance(path, str):
path = Path(path)
res = cfg.get_spkr(path)
res = cfg.get_speaker(path)
return res
def get_speaker_group(self, path):
if isinstance(path, str):
path = Path(path)
res = cfg.get_spkr_group(path)
res = cfg.get_speaker_group(path)
return res
# this isn't really necessary since our data/metadata contains markers for languages, but it's kept in case a language setting needs to be forced (for example, whisperX's language detection isn't always accurate)
@@ -1071,10 +975,6 @@ class Dataset(_Dataset):
return lang.lower()
@cached_property
def spkrs(self):
return sorted({self.get_speaker(path) for path in self.paths})
@cached_property
def tasks(self):
if not self.training:
@@ -1088,13 +988,16 @@ class Dataset(_Dataset):
if not path.parent.exists():
path.parent.mkdir(parents=True, exist_ok=True)
state_dict = self.sampler.get_state()
"""
if self.sampler_type == "path":
state_dict = self.sampler.get_state()
else:
state_dict = {
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
"speaker_samplers": { name: sampler.get_state() for name, sampler in self.speaker_samplers.items() },
}
"""
if "dataset_hash_key" not in state_dict:
state_dict["dataset_hash_key"] = self.dataset_hash_key
@@ -1117,6 +1020,8 @@ class Dataset(_Dataset):
_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
return
state_dict = self.sampler.set_state(state_dict)
"""
if self.sampler_type == "path":
state_dict = self.sampler.set_state(state_dict)
else:
@@ -1125,19 +1030,17 @@ class Dataset(_Dataset):
continue
self.samplers[name].set_state( sampler )
for name, sampler in state_dict["spkr_samplers"].items():
if name not in self.spkr_samplers:
for name, sampler in state_dict["speaker_samplers"].items():
if name not in self.speaker_samplers:
continue
self.spkr_samplers[name].set_state( sampler )
self.speaker_samplers[name].set_state( sampler )
"""
def _get_phone_symmap(self):
return get_phone_symmap()
def _get_spkr_symmap(self):
return {s: i for i, s in enumerate(self.spkrs)}
def _get_spkr_group_symmap(self):
return {s: i for i, s in enumerate(self.spkr_groups)}
def _get_speaker_symmap(self):
return {s: i for i, s in enumerate(self.speakers)}
def _get_lang_symmap(self):
return get_lang_symmap()
@@ -1159,17 +1062,19 @@ class Dataset(_Dataset):
return qnt
def sample_speakers(self, ignore=[]):
choices = set(self.spkrs) - set(ignore)
choices = set(self.speakers) - set(ignore)
return random.choice([*choices])
def sample_utterance(self, spkr_name, ignore=[]):
choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
def sample_utterance(self, speaker_name, ignore=[]):
choices = [*(set(self.metadata[speaker_name].keys()) - set(ignore))]
if len(choices) == 0:
return None, None, None
path = random.choice(choices)
utterance_id = random.choice(choices)
utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@@ -1201,13 +1106,11 @@ class Dataset(_Dataset):
return path, text, resps
# icky slop
def get_similar_utterance(self, path, offset=None ):
def get_similar_utterance(self, speaker_name, utterance_name, offset=None ):
if offset is None:
offset = cfg.dataset.prompt_similar_top_k_offset
root = Path( *path.parts[:-1] )
reference = path.name
similars = _similar_map[self.dataset_type].get(str(path), None)
_, similars = self.metadata[speaker_name][utterance_name]
if not similars:
return None
@@ -1223,41 +1126,28 @@ class Dataset(_Dataset):
if offset_end >= len( similars ):
return None
utterance_keys = list(self.metadata[speaker_name].keys())
if cfg.dataset.prompt_similar_top_k > 1:
name = random.choice( similars[offset:offset_end] )
index = random.choice( similars[offset:offset_end] )
else:
name = similars[offset]
index = similars[offset]
path = root / name
return utterance_keys[index]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5[key]:
return None
elif not path.exists():
return None
return path
def sample_prompts(self, spkr_name, reference, should_trim=True):
def sample_prompts(self, speaker_name, utterance_name=None, should_trim=True):
# return no prompt if explicitly requested for who knows why
# or if the speaker has no other utterances to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
if len(self.paths_by_spkr_name[spkr_name]) <= 1:
if len(self.metadata[speaker_name]) <= 1:
return None
prom_list = []
choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
choices = set(self.metadata[speaker_name].keys()) - {utterance_name}
choices = [*choices]
# no other utterances; it'd make more sense to prune speakers with only one utterance in the validation step
if len(choices) == 0:
choices = [*set(self.paths_by_spkr_name[spkr_name])]
"""
raise ValueError(
f"Failed to find another different utterance for {spkr_name}."
)
"""
choices = [*set(self.metadata[speaker_name].keys())]
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
should_trim = False
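get_similar_utterance now works purely in index space: the stored similars are indices into the speaker's utterance-key list, and offset skips however many prompts were already drawn. A sketch of the selection, assuming offset_end is offset + prompt_similar_top_k (the truncated hunk doesn't show its computation):

import random

def pick_similar(similars, utterance_keys, offset=0, top_k=1):
    offset_end = offset + top_k  # assumed; not shown in the hunk
    if offset_end >= len(similars):
        return None
    if top_k > 1:
        index = random.choice(similars[offset:offset_end])
    else:
        index = similars[offset]
    return utterance_keys[index]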
@@ -1271,17 +1161,18 @@ class Dataset(_Dataset):
for _ in range(cfg.dataset.prompt_max_samples):
# yuck
path = None
if reference is not None:
sampled_utterance = None
if utterance_name is not None:
if random.random() < cfg.dataset.prompt_similar_p:
try:
path = self.get_similar_utterance( reference, offset = len(prom_list) )
sampled_utterance = self.get_similar_utterance( speaker_name, utterance_name, offset = len(prom_list) )
except Exception as e:
path = None
sampled_utterance = None
if not path:
path = random.choice(choices)
if not sampled_utterance:
sampled_utterance = random.choice(choices)
path = cfg.data_dir / speaker_name / sampled_utterance
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
@@ -1294,6 +1185,7 @@ class Dataset(_Dataset):
except Exception as e:
_logger.warning(f'Failed to load artifact: {path} ({e})')
path = None
continue
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
@@ -1321,27 +1213,10 @@ class Dataset(_Dataset):
bos_id, space_id, eos_id = self.empty_text
if self.sampler_type == "group":
spkr_group = self.spkr_groups[index]
#spkr_group_id = self.spkr_group_symmap[spkr_group]
spkr_name = self.spkr_samplers[spkr_group].sample()
spkr_id = self.spkr_symmap[spkr_name]
path = self.samplers[spkr_name].sample()
elif self.sampler_type == "speaker":
spkr_name = self.spkrs[index]
spkr_id = self.spkr_symmap[spkr_name]
path = self.samplers[spkr_name].sample()
spkr_group = self.get_speaker_group(path)
#spkr_group_id = self.spkr_group_symmap[spkr_group]
else:
path = self.paths[index]
spkr_name = self.get_speaker(path)
spkr_id = self.spkr_symmap[spkr_name]
spkr_group = self.get_speaker_group(path)
#spkr_group_id = self.spkr_group_symmap[spkr_group]
if not isinstance( path, Path ):
path = Path( path )
speaker_id, utterance_id = self.paths[index]
speaker_name = self.speakers[speaker_id]
utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@@ -1378,7 +1253,7 @@ class Dataset(_Dataset):
tone = metadata["tone"] if "tone" in metadata else None
text_string = metadata["text"] if "text" in metadata else None
lang = self.get_language(spkr_group) if not lang else lang.lower()
lang = lang.lower() if lang else "en"
raw_text = torch.tensor(text_tokenize(text_string)).to(torch.int16) if text_string else None
@@ -1398,7 +1273,7 @@ class Dataset(_Dataset):
if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
ignore_paths = []
for _ in range( 1, cfg.dataset.resps_max_samples ):
path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
path, txt, qnt = self.sample_utterance(speaker_name, ignore=ignore_paths)
ignore_paths.append(path)
# <s>[original text]</s><s>[new text]</s>
@@ -1412,7 +1287,7 @@ class Dataset(_Dataset):
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
task = random.choice(self.tasks)
if task not in self.task_symmap:
@@ -1420,8 +1295,10 @@ class Dataset(_Dataset):
# Base TTS (<text><prompt> => <resp>)
if task == "tts":
proms = self.sample_prompts(spkr_name, reference=path)
proms = self.sample_prompts(speaker_name, utterance_name)
"""
proms = self.sample_prompts(speaker_name, reference=path)
if random.random() < cfg.dataset.prompt_inject_noise_p:
# sample random noise
noise = self.sample_noise()
@@ -1429,7 +1306,7 @@ class Dataset(_Dataset):
noise = repeat_extend_audio(noise, proms.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
"""
# VALL-E Continuous (<text><partial resp> => <remaining resp> )
# (this could just be sampled as <text a><text b><audio a> => <audio b>, but I need to experiment with it)
@@ -1442,7 +1319,7 @@ class Dataset(_Dataset):
proms = resps[:trim_length, :]
resps = resps[trim_length:, :]
else:
path, txt, qnt = self.sample_utterance(spkr_name)
path, txt, qnt = self.sample_utterance(speaker_name)
# <s>[original text]</s><s>[new text]</s>
if naive:
@@ -1471,7 +1348,7 @@ class Dataset(_Dataset):
# Duration prediction (<text><prompt> => len(<resp>))
elif task == "len":
proms = self.sample_prompts(spkr_name, reference=path)
proms = self.sample_prompts(speaker_name, utterance_name)
elif task in ["phn", "un-phn"]:
proms = []
@@ -1503,10 +1380,10 @@ class Dataset(_Dataset):
# target speech extraction ( <text><prom><resp + other resp> => <resp> )
elif task == "tse":
# sample a prompt
proms = self.sample_prompts(spkr_name, reference=path)
proms = self.sample_prompts(speaker_name, utterance_name)
# sample another speaker
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[spkr_name]))
_, __, other_resps = self.sample_utterance(self.sample_speakers(ignore=[speaker_name]))
# overlay the random speaker over the target audio
other_resps = merge_audio( resps, other_resps, scale=[1, random.uniform(0.5, 0.75)], device=cfg.dataset.reencode_device )
@@ -1531,7 +1408,7 @@ class Dataset(_Dataset):
samples = []
for _ in range( 4 ):
sampled = self.sample_utterance(spkr_name, ignore=[s[0] for s in samples])
sampled = self.sample_utterance(speaker_name, ignore=[s[0] for s in samples])
samples.append( sampled )
pre_text, mid_text, post_text, edit_text = [ s[1][1:-1] for s in samples ]
@@ -1611,8 +1488,8 @@ class Dataset(_Dataset):
return dict(
index=index,
path=Path(path),
spkr_name=spkr_name,
spkr_id=spkr_id,
speaker_name=speaker_name,
speaker_id=speaker_id,
task=task,
lang=lang,
tone=tone,
@@ -1640,10 +1517,8 @@ class Dataset(_Dataset):
return len(self.sampler if self.sampler is not None else self) // self.batch_size
def __len__(self):
if self.sampler_type == "group":
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
if self.sampler_type == "speaker":
return min(len(self.spkrs), self._head or len(self.spkrs))
return min(len(self.speakers), self._head or len(self.speakers))
return min(len(self.paths), self._head or len(self.paths))
@@ -1690,8 +1565,7 @@ def create_train_dataloader():
train_dl = _create_dataloader(train_dataset, training=True)
_logger.info(str(train_dataset.phone_symmap))
_logger.info(str(train_dataset.spkr_symmap))
_logger.info(str(train_dataset.spkr_group_symmap))
_logger.info(str(train_dataset.speaker_symmap))
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#duration (train): {str(train_dataset.duration)}.")
@@ -1706,8 +1580,7 @@ def create_val_dataloader():
val_dl = _create_dataloader(val_dataset, training=False)
_logger.info(str(val_dataset.phone_symmap))
_logger.info(str(val_dataset.spkr_symmap))
_logger.info(str(val_dataset.spkr_group_symmap))
_logger.info(str(val_dataset.speaker_symmap))
_logger.info(f"#samples (val): {len(val_dataset)}.")
_logger.info(f"#duration (val): {str(val_dataset.duration)}.")
@@ -1724,8 +1597,7 @@ def create_train_val_dataloader():
val_dl = _create_dataloader(val_dataset, training=False)
_logger.info(str(train_dataset.phone_symmap))
_logger.info(f'#speakers (train): {len(train_dataset.spkr_symmap)}')
_logger.info(f'#groups (train): {len(train_dataset.spkr_group_symmap)}')
_logger.info(f'#speakers (train): {len(train_dataset.speaker_symmap)}')
_logger.info(f"#samples (train): {len(train_dataset)}.")
_logger.info(f"#samples (val): {len(val_dataset)}.")
@@ -2018,7 +1890,6 @@ if __name__ == "__main__":
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
@@ -2044,31 +1915,18 @@ if __name__ == "__main__":
for i in range(len(v)):
_logger.info(f'{k}[{i}]: {v[i]}')
elif args.action == "validate":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
train_dl, val_dl = create_train_val_dataloader()
dataset = train_dl.dataset
missing = []
symmap = get_phone_symmap()
for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
if dataset.sampler_type == "group":
spkr_group = dataset.spkr_groups[index]
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
spkr_name = dataset.spkr_samplers[spkr_group].sample()
spkr_id = dataset.spkr_symmap[spkr_name]
path = dataset.samplers[spkr_name].sample()
elif dataset.sampler_type == "speaker":
spkr_name = dataset.spkrs[index]
spkr_id = dataset.spkr_symmap[spkr_name]
path = dataset.samplers[spkr_name].sample()
spkr_group = dataset.get_speaker_group(path)
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
else:
path = dataset.paths[index]
spkr_name = dataset.get_speaker(path)
spkr_id = dataset.spkr_symmap[spkr_name]
spkr_group = dataset.get_speaker_group(path)
#spkr_group_id = dataset.spkr_group_symmap[spkr_group]
speaker_id, utterance_id = dataset.paths[index]
speaker_name = dataset.speakers[speaker_id]
speaker_keys = list(dataset.metadata[speaker_name].keys())
utterance_name = speaker_keys[utterance_id]
path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@@ -2113,7 +1971,7 @@ if __name__ == "__main__":
index = 0
cfg.dataset.tasks_list = args.tasks.split(",")
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
train_dl, val_dl = create_train_val_dataloader()
batch = next(iter(train_dl))
for text, resps, proms, task in zip(batch["text"], batch["resps"], batch["proms"], batch["task"]):