diff --git a/vall_e/config.py b/vall_e/config.py
index 4ded6b4..d6e6252 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -881,25 +881,26 @@ class Config(BaseConfig):
supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
def set_audio_backend(self, audio_backend):
- cfg.audio_backend = audio_backend
+ self.audio_backend = audio_backend
+
audio_extension = None
if audio_backend in ["encodec", "vocos"]:
audio_extension = ".enc"
- cfg.sample_rate = 24_000
- #cfg.model.resp_levels = 8
+ self.sample_rate = 24_000
+ #self.model.resp_levels = 8
elif audio_backend == "dac":
audio_extension = ".dac"
- cfg.sample_rate = 44_100
- #cfg.model.resp_levels = 9
- elif cfg.audio_backend == "audiodec":
+ self.sample_rate = 44_100
+ #self.model.resp_levels = 9
+ elif self.audio_backend == "audiodec":
audio_extension = ".dec"
- cfg.sample_rate = 48_000
- #cfg.model.resp_levels = 8 # ?
- elif cfg.audio_backend == "nemo":
+ self.sample_rate = 48_000
+ #self.model.resp_levels = 8 # ?
+ elif self.audio_backend == "nemo":
audio_extension = ".nem"
- cfg.sample_rate = 44_100
- #cfg.model.resp_levels = 8
- #cfg.model.audio_tokens = 1000
+ self.sample_rate = 44_100
+ #self.model.resp_levels = 8
+ #self.model.audio_tokens = 1000
else:
raise Exception(f"Unknown audio backend: {audio_backend}")
diff --git a/vall_e/data.py b/vall_e/data.py
index bcf1747..92f078a 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -702,136 +702,75 @@ def _get_metadata_extension():
def _get_artifact_path(path):
return _replace_file_extension(path, _get_artifact_extension())
-_durations_map = {}
-_similar_map = {}
+def _get_path_key( type, dir, id ):
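+	# yields an HDF5 key of the form "/{type}/{dir}/{id}" when using HDF5, otherwise an on-disk path "{dir}/{id}"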
+ return f"/{type}/{_get_hdf5_path(dir)}/{id}" if cfg.dataset.use_hdf5 else str(dir / id)
-def _get_duration_map( type="training" ):
- return _durations_map[type] if type in _durations_map else {}
-
-def _get_similar_map( type="training" ):
- return _similar_map[type] if type in _similar_map else {}
-
-def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
+def _load_dataset_metadata(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
+
+	# for now, require the metadata-based codepath
+ assert cfg.dataset.use_metadata, "Metadata required."
if not dataset_hash_key:
dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
cached_dir = cfg.cache_dir / dataset_hash_key
-
- cached_durations_path = cached_dir / f"durations[{type}].json"
- cached_paths_path = cached_dir / f"dataloader[{type}].json"
- cached_similar_path = cached_dir / f"similar[{type}].json"
-
- # load the duration table first, since this is independent from the loaded paths
- if cached_durations_path.exists():
- _durations_map[type] = json_read( cached_durations_path )
- # load the similar paths table as well, since this is also independent
- if cached_similar_path.exists():
- _similar_map[type] = json_read( cached_similar_path )
-
- # load the cached valid paths (if we're requesting cache use)
- if cached_paths_path.exists() and cfg.dataset.cache:
- # to-do: automatic conversion between HDF5 formatted paths and on-disk paths
- return json_read( cached_paths_path )
+ cached_path = cached_dir / f"dataset[{type}].json"
- # deduce valid paths
- paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+ if cached_path.exists() and cfg.dataset.cache:
+ return json_read( cached_path )
+
+ dataset_metadata = {}
+	def validate_utterance( dir, id, entry ):
+ duration = entry.get('duration', 0)
+		in_bounds = cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration
+
+ if cfg.dataset.validate and type == "training" and not in_bounds:
+ return False
+
+ if cfg.dataset.strict_validate:
+			if cfg.dataset.use_hdf5 and _get_path_key(type, dir, id) not in cfg.hdf5:
+ return False
+
+ if not (cfg.data_dir / dir / id).with_suffix(_get_artifact_extension()).exists():
+ return False
+
+ return True
+
+ def process_utterance( id, entry, metadata_keys=None ):
+ duration = entry.get('duration', 0)
+ similar = entry.get('similar', None)
+		# store the duration and the names of similar utterances (names rather than raw indices, since indices shift once entries are culled)
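+		# e.g. [2.5, ["utt_a", "utt_b"]] (hypothetical values); the names are remapped to filtered indices below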
+ return [duration, ([ metadata_keys[idx] for idx in similar ] if similar and metadata_keys else [])]
+
+ for dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent ):
+ metadata_path = cfg.metadata_dir / f'{dir}.json'
+ if not metadata_path.exists():
+ continue
+
+		# to-do: make json_read gracefully handle files it can't parse
+ try:
+ metadata = json_read( metadata_path )
+ except Exception as e:
+ continue
+
+ speaker = str(dir)
+ metadata_keys = list(metadata.keys())
+		dataset_metadata[speaker] = { id: process_utterance( id, entry, metadata_keys ) for id, entry in metadata.items() if validate_utterance( dir, id, entry ) }
+
+ # remap strings to indices
+ remapped_indices = { k: i for i, k in enumerate(dataset_metadata[speaker].keys()) }
+ for id, (duration, similars) in dataset_metadata[speaker].items():
+ dataset_metadata[speaker][id][1] = [ remapped_indices[k] for k in similars if k in remapped_indices ]
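+		# dataset_metadata[speaker] now maps utterance id => [duration, [indices of similar utterances]]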
# and write if global leader (to avoid other processes writing to the same file at once)
if is_global_leader():
if not cached_dir.exists():
cached_dir.mkdir(parents=True, exist_ok=True)
- json_write( _similar_map[type], cached_similar_path, truncate=True )
- json_write( _durations_map[type], cached_durations_path, truncate=True )
- json_write( paths, cached_paths_path, truncate=True )
+ json_write( dataset_metadata, cached_path, truncate=True )
- return paths
-
-def _load_paths_from_metadata(group_name, type="training", validate=False):
- data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
-
- _fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
-
- def key( id, entry=None ):
- return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else str(data_dir / id)
-
- metadata_path = cfg.metadata_dir / f'{group_name}.json'
- metadata = {}
-
- if cfg.dataset.use_metadata and metadata_path.exists():
- try:
- metadata = json_read( metadata_path )
- except Exception as e:
- return {}
-
- if len(metadata) == 0:
- return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
-
- # this might be slow
- def _exists( id, entry ):
- if not cfg.dataset.strict_validate:
- return True
-
- if cfg.dataset.use_hdf5:
- return key(id, entry) in cfg.hdf5
-
- return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
-
- metadata_keys = list(metadata.keys())
- def _validate( id, entry ):
- phones = entry.get('phones', 0)
- duration = entry.get('duration', 0)
- similar = entry.get('similar', None)
-
- k = key(id, entry)
-
- # add to duration bucket
- if type not in _durations_map:
- _durations_map[type] = {}
- _durations_map[type][k] = duration
-
- # add to similar bucket
- if type not in _similar_map:
- _similar_map[type] = {}
- _similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
-
- if not validate:
- return True
-
- in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
- if in_bounds and not _exists( id, entry ):
- return False
-
- return in_bounds
-
- return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
-
-
-def _get_hdf5_path(path):
- # to-do: better validation
- return str(path)
-
-def _get_hdf5_paths( data_dir, type="training", validate=False ):
- data_dir = str(data_dir)
-
- key = f"/{type}/{_get_hdf5_path(data_dir)}"
-
- def _validate( id, entry ):
- phones = entry.attrs['phonemes']
- duration = entry.attrs['duration']
-
- if type not in _durations_map:
- _durations_map[type] = {}
- _durations_map[type][f"{key}/{id}"] = float(duration)
-
- if not validate:
- return True
-
- return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
-
- return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
+	return dataset_metadata
+
+# kept: _get_path_key and the dataloader's HDF5 lookups still rely on this
+def _get_hdf5_path(path):
+	# to-do: better validation
+	return str(path)
+
def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
if isinstance(path, str):
@@ -878,7 +817,7 @@ class Dataset(_Dataset):
self,
phone_symmap=None,
training=False,
- extra_paths_by_spkr_name: dict[str, list] = {},
+ extra_paths_by_speaker_name: dict[str, list] = {},
):
super().__init__()
@@ -910,40 +849,31 @@ class Dataset(_Dataset):
if self.sampler_order == "duration" and self.sampler_type != "path":
raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
- # dict of paths keyed by speaker names
- self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
- self.duration_map = _get_duration_map( self.dataset_type )
-
- # cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
- if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
- keys = list(self.paths_by_spkr_name.keys())
- for key in keys:
- if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
- del self.paths_by_spkr_name[key]
- continue
-
- # slice away extraneous utterances
- if cfg.dataset.max_utterances:
- self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
-
- # flatten paths
- self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
+		# dict that maps [speaker][utterance id] to [duration, indices of similar utterances]
+ self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
+		# cull speakers with too few utterances
+		# (iterate over a copy of the keys, since entries are deleted mid-loop)
+		for speaker in list(self.metadata.keys()):
+			utterances = len(self.metadata[speaker])
+			if utterances < cfg.dataset.min_utterances:
+				del self.metadata[speaker]
+
+ self.paths = []
+ self.speakers = list(self.metadata.keys())
+ for speaker_id, speaker in enumerate(self.speakers):
+ utterances = len(self.metadata[speaker])
+ self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
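+		# self.paths is now a flat list of (speaker index, utterance index) pairs rather than filesystem paths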
+
# split dataset accordingly per GPU
if cfg.distributed and self.training:
self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
- # recreate paths_by_spkr_name
- self.paths_by_spkr_name = {}
- for path in self.paths:
- name = cfg.get_spkr( Path(path) )
- if name not in self.paths_by_spkr_name:
- self.paths_by_spkr_name[name] = []
- self.paths_by_spkr_name[name].append( path )
-
# store in corresponding bucket
- for path in self.paths:
- duration = self.duration_map[path]
+ for (speaker_id, utterance_id) in self.paths:
+ speaker = self.speakers[speaker_id]
+ utterance = list(self.metadata[speaker].keys())[utterance_id]
+
+ duration, _ = self.metadata[speaker][utterance]
self.duration += duration
# only calc duration if we're going to order by duration
@@ -953,7 +883,7 @@ class Dataset(_Dataset):
bucket = int(round(duration))
if bucket not in self.duration_buckets:
self.duration_buckets[bucket] = []
- self.duration_buckets[bucket].append( ( Path(path), duration ) )
+ self.duration_buckets[bucket].append( ( (speaker_id, utterance_id), duration ) )
# sort by duration
if self.sampler_order == "duration":
@@ -964,43 +894,28 @@ class Dataset(_Dataset):
# sort and interleave
for bucket in self.duration_buckets:
# sort by duration
- self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+ self.duration_buckets[bucket].sort( key=lambda x: x[-1] )
# split to retain tuples
flattened[bucket] = self.duration_buckets[bucket]
# replace with path
flattened[bucket] = [ x[0] for x in flattened[bucket] ]
# flatten by paths
- flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
+ flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
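+				# x[0] is the speaker index, so interleaving still alternates across speakers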
# flatten paths
self.paths = list(itertools.chain.from_iterable(flattened.values()))
elif self.sampler_order == "random":
random.shuffle( self.paths )
else:
# just interleave
- self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
-
- # dict of speakers keyed by speaker group
- self.spkrs_by_spkr_group = {}
- for data_dir in self.dataset:
- spkr = cfg.get_spkr( data_dir / "dummy" )
- spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
-
- if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
- continue
-
- if spkr_group not in self.spkrs_by_spkr_group:
- self.spkrs_by_spkr_group[spkr_group] = []
-
- self.spkrs_by_spkr_group[spkr_group].append( spkr )
-
- self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
+ self.paths = [*_interleaved_reorder(self.paths, lambda x: x[0])]
+ """
self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
+ """
self.phone_symmap = phone_symmap or self._get_phone_symmap()
- self.spkr_symmap = self._get_spkr_symmap()
- self.spkr_group_symmap = self._get_spkr_group_symmap()
+ self.speaker_symmap = self._get_speaker_symmap()
self.lang_symmap = self._get_lang_symmap()
self.tone_symmap = self._get_tone_symmap()
self.task_symmap = self._get_task_symmap()
@@ -1022,27 +937,16 @@ class Dataset(_Dataset):
if len(self.paths) == 0:
raise ValueError(f"No valid path is found for {self.dataset_type}")
- if self.sampler_type == "path" and self.training:
- if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
- self.sampler = BatchedOrderedSampler(
- self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
- max_duration=cfg.dataset.sample_max_duration_batch,
- max_batch_size=self.batch_size,
- shuffle=self.sampler_shuffle,
- )
- self.batch_size = 1
- else:
- self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
- self.samplers = {}
- self.spkr_samplers = {}
+ if self.training and self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+ self.sampler = BatchedOrderedSampler(
+ self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+ max_duration=cfg.dataset.sample_max_duration_batch,
+ max_batch_size=self.batch_size,
+ shuffle=self.sampler_shuffle,
+ )
+ self.batch_size = 1
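+			# BatchedOrderedSampler presumably yields whole batches itself, so the dataloader batch size collapses to 1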
else:
- self.sampler = RandomSampler( len(self) )
- self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
- self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
-
- # dereference buckets
- self.duration_map = None
- self.duration_buckets = None
+ self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
self.load_state_dict()
@@ -1053,13 +957,13 @@ class Dataset(_Dataset):
def get_speaker(self, path):
if isinstance(path, str):
path = Path(path)
- res = cfg.get_spkr(path)
+ res = cfg.get_speaker(path)
return res
def get_speaker_group(self, path):
if isinstance(path, str):
path = Path(path)
- res = cfg.get_spkr_group(path)
+ res = cfg.get_speaker_group(path)
return res
# this isn't really necessary since our data/metadata contains markers for languages, but this is still in in-case it's needed to force a language setting (for example, whisperX's lang isn't that accurate at times)
@@ -1071,10 +975,6 @@ class Dataset(_Dataset):
return lang.lower()
- @cached_property
- def spkrs(self):
- return sorted({self.get_speaker(path) for path in self.paths})
-
@cached_property
def tasks(self):
if not self.training:
@@ -1088,13 +988,16 @@ class Dataset(_Dataset):
if not path.parent.exists():
path.parent.mkdir(parents=True, exist_ok=True)
+ state_dict = self.sampler.get_state()
+ """
if self.sampler_type == "path":
state_dict = self.sampler.get_state()
else:
state_dict = {
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
- "spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
+ "speaker_samplers": { name: sampler.get_state() for name, sampler in self.speaker_samplers.items() },
}
+ """
if "dataset_hash_key" not in state_dict:
state_dict["dataset_hash_key"] = self.dataset_hash_key
@@ -1117,6 +1020,8 @@ class Dataset(_Dataset):
_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
return
+ state_dict = self.sampler.set_state(state_dict)
+ """
if self.sampler_type == "path":
state_dict = self.sampler.set_state(state_dict)
else:
@@ -1125,19 +1030,17 @@ class Dataset(_Dataset):
continue
self.samplers[name].set_state( sampler )
- for name, sampler in state_dict["spkr_samplers"].items():
- if name not in self.spkr_samplers:
+ for name, sampler in state_dict["speaker_samplers"].items():
+ if name not in self.speaker_samplers:
continue
- self.spkr_samplers[name].set_state( sampler )
+ self.speaker_samplers[name].set_state( sampler )
+ """
def _get_phone_symmap(self):
return get_phone_symmap()
- def _get_spkr_symmap(self):
- return {s: i for i, s in enumerate(self.spkrs)}
-
- def _get_spkr_group_symmap(self):
- return {s: i for i, s in enumerate(self.spkr_groups)}
+ def _get_speaker_symmap(self):
+ return {s: i for i, s in enumerate(self.speakers)}
def _get_lang_symmap(self):
return get_lang_symmap()
@@ -1159,17 +1062,19 @@ class Dataset(_Dataset):
return qnt
def sample_speakers(self, ignore=[]):
- choices = set(self.spkrs) - set(ignore)
+ choices = set(self.speakers) - set(ignore)
return random.choice([*choices])
- def sample_utterance(self, spkr_name, ignore=[]):
- choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
+	def sample_utterance(self, speaker_name, ignore=[]):
+		# ignore holds paths returned by earlier samples; reduce them to utterance names, since choices are metadata keys
+		ignored = { Path(p).name for p in ignore if p is not None }
+		choices = [*(set(self.metadata[speaker_name].keys()) - ignored)]
if len(choices) == 0:
return None, None, None
- path = random.choice(choices)
+		# choices already holds utterance names (keys), so sample one directly
+		utterance_name = random.choice(choices)
+ path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@@ -1201,13 +1106,11 @@ class Dataset(_Dataset):
return path, text, resps
# icky slop
- def get_similar_utterance(self, path, offset=None ):
+ def get_similar_utterance(self, speaker_name, utterance_name, offset=None ):
if offset is None:
offset = cfg.dataset.prompt_similar_top_k_offset
- root = Path( *path.parts[:-1] )
- reference = path.name
- similars = _similar_map[self.dataset_type].get(str(path), None)
+ _, similars = self.metadata[speaker_name][utterance_name]
if not similars:
return None
@@ -1223,41 +1126,28 @@ class Dataset(_Dataset):
if offset_end >= len( similars ):
return None
+ utterance_keys = list(self.metadata[speaker_name].keys())
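+		# entries in similars are indices into this speaker's filtered utterance key list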
if cfg.dataset.prompt_similar_top_k > 1:
- name = random.choice( similars[offset:offset_end] )
+ index = random.choice( similars[offset:offset_end] )
else:
- name = similars[offset]
+ index = similars[offset]
- path = root / name
+ return utterance_keys[index]
- if cfg.dataset.use_hdf5:
- key = _get_hdf5_path(path)
- if key not in cfg.hdf5[key]:
- return None
- elif not path.exists():
- return None
-
- return path
-
- def sample_prompts(self, spkr_name, reference, should_trim=True):
+ def sample_prompts(self, speaker_name, utterance_name=None, should_trim=True):
# return no prompt if explicitly requested for who knows why
# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
- if len(self.paths_by_spkr_name[spkr_name]) <= 1:
+ if len(self.metadata[speaker_name]) <= 1:
return None
prom_list = []
- choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
+ choices = set(self.metadata[speaker_name].keys()) - {utterance_name}
choices = [*choices]
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
if len(choices) == 0:
- choices = [*set(self.paths_by_spkr_name[spkr_name])]
- """
- raise ValueError(
- f"Failed to find another different utterance for {spkr_name}."
- )
- """
+ choices = [*set(self.metadata[speaker_name].keys())]
if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
should_trim = False
@@ -1271,17 +1161,18 @@ class Dataset(_Dataset):
for _ in range(cfg.dataset.prompt_max_samples):
# yuck
- path = None
- if reference is not None:
+ sampled_utterance = None
+ if utterance_name is not None:
if random.random() < cfg.dataset.prompt_similar_p:
try:
- path = self.get_similar_utterance( reference, offset = len(prom_list) )
+ sampled_utterance = self.get_similar_utterance( speaker_name, utterance_name, offset = len(prom_list) )
except Exception as e:
- path = None
+ sampled_utterance = None
- if not path:
- path = random.choice(choices)
+ if not sampled_utterance:
+ sampled_utterance = random.choice(choices)
+ path = cfg.data_dir / speaker_name / sampled_utterance
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
@@ -1294,6 +1185,7 @@ class Dataset(_Dataset):
except Exception as e:
_logger.warning(f'Failed to load artifact: {path} ({e})')
path = None
+ continue
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
@@ -1321,27 +1213,10 @@ class Dataset(_Dataset):
bos_id, space_id, eos_id = self.empty_text
- if self.sampler_type == "group":
- spkr_group = self.spkr_groups[index]
- #spkr_group_id = self.spkr_group_symmap[spkr_group]
- spkr_name = self.spkr_samplers[spkr_group].sample()
- spkr_id = self.spkr_symmap[spkr_name]
- path = self.samplers[spkr_name].sample()
- elif self.sampler_type == "speaker":
- spkr_name = self.spkrs[index]
- spkr_id = self.spkr_symmap[spkr_name]
- path = self.samplers[spkr_name].sample()
- spkr_group = self.get_speaker_group(path)
- #spkr_group_id = self.spkr_group_symmap[spkr_group]
- else:
- path = self.paths[index]
- spkr_name = self.get_speaker(path)
- spkr_id = self.spkr_symmap[spkr_name]
- spkr_group = self.get_speaker_group(path)
- #spkr_group_id = self.spkr_group_symmap[spkr_group]
-
- if not isinstance( path, Path ):
- path = Path( path )
+ speaker_id, utterance_id = self.paths[index]
+ speaker_name = self.speakers[speaker_id]
+ utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
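+		# rebuild the utterance's on-disk path; assumes artifacts live under data_dir / speaker / utterance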
+ path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
@@ -1378,7 +1253,7 @@ class Dataset(_Dataset):
tone = metadata["tone"] if "tone" in metadata else None
text_string = metadata["text"] if "text" in metadata else None
- lang = self.get_language(spkr_group) if not lang else lang.lower()
+ lang = lang.lower() if lang else "en"
raw_text = torch.tensor(text_tokenize(text_string)).to(torch.int16) if text_string else None
@@ -1398,7 +1273,7 @@ class Dataset(_Dataset):
if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
ignore_paths = []
for _ in range( 1, cfg.dataset.resps_max_samples ):
- path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
+ path, txt, qnt = self.sample_utterance(speaker_name, ignore=ignore_paths)
ignore_paths.append(path)
# [original text][new text]
@@ -1412,7 +1287,7 @@ class Dataset(_Dataset):
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
-
+
task = random.choice(self.tasks)
if task not in self.task_symmap:
@@ -1420,8 +1295,10 @@ class Dataset(_Dataset):
# Base TTS ( => )
if task == "tts":
- proms = self.sample_prompts(spkr_name, reference=path)
+ proms = self.sample_prompts(speaker_name, utterance_name)
+ """
+ proms = self.sample_prompts(speaker_name, reference=path)
if random.random() < cfg.dataset.prompt_inject_noise_p:
# sample random noise
noise = self.sample_noise()
@@ -1429,7 +1306,7 @@ class Dataset(_Dataset):
noise = repeat_extend_audio(noise, proms.shape[0])
# create the input prompt by merging the target audio with the noise
proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
-
+ """
# VALL-E Continuous ( => )
# (this could just be sampled as