diff --git a/vall_e/config.py b/vall_e/config.py
index 4ded6b4..d6e6252 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -881,25 +881,26 @@ class Config(BaseConfig):
 	supported_weights_formats: list[str] = field(default_factory=lambda: ["sft", "safetensors", "pt", "pth"])
 
 	def set_audio_backend(self, audio_backend):
-		cfg.audio_backend = audio_backend
+		self.audio_backend = audio_backend
+		audio_extension = None
 		if audio_backend in ["encodec", "vocos"]:
 			audio_extension = ".enc"
-			cfg.sample_rate = 24_000
-			#cfg.model.resp_levels = 8
+			self.sample_rate = 24_000
+			#self.model.resp_levels = 8
 		elif audio_backend == "dac":
 			audio_extension = ".dac"
-			cfg.sample_rate = 44_100
-			#cfg.model.resp_levels = 9
-		elif cfg.audio_backend == "audiodec":
+			self.sample_rate = 44_100
+			#self.model.resp_levels = 9
+		elif self.audio_backend == "audiodec":
 			audio_extension = ".dec"
-			cfg.sample_rate = 48_000
-			#cfg.model.resp_levels = 8 # ?
-		elif cfg.audio_backend == "nemo":
+			self.sample_rate = 48_000
+			#self.model.resp_levels = 8 # ?
+		elif self.audio_backend == "nemo":
 			audio_extension = ".nem"
-			cfg.sample_rate = 44_100
-			#cfg.model.resp_levels = 8
-			#cfg.model.audio_tokens = 1000
+			self.sample_rate = 44_100
+			#self.model.resp_levels = 8
+			#self.model.audio_tokens = 1000
 		else:
 			raise Exception(f"Unknown audio backend: {audio_backend}")
diff --git a/vall_e/data.py b/vall_e/data.py
index bcf1747..92f078a 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -702,136 +702,75 @@ def _get_metadata_extension():
 def _get_artifact_path(path):
 	return _replace_file_extension(path, _get_artifact_extension())
 
-_durations_map = {}
-_similar_map = {}
+def _get_path_key( type, dir, id ):
+	return f"/{type}/{_get_hdf5_path(dir)}/{id}" if cfg.dataset.use_hdf5 else str(dir / id)
 
-def _get_duration_map( type="training" ):
-	return _durations_map[type] if type in _durations_map else {}
-
-def _get_similar_map( type="training" ):
-	return _similar_map[type] if type in _similar_map else {}
-
-def _load_paths(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
+def _load_dataset_metadata(dataset, type="training", silent=not is_global_leader(), dataset_hash_key=None):
 	assert cfg.dataset.min_duration >= 1.0, "Minimum duration too low."
+
+	# for now, only support the metadata-based path
+	assert cfg.dataset.use_metadata, "Metadata required."
 	if not dataset_hash_key:
 		dataset_hash_key = cfg.dataset.hash_key(sorted(dataset))
 
 	cached_dir = cfg.cache_dir / dataset_hash_key
-
-	cached_durations_path = cached_dir / f"durations[{type}].json"
-	cached_paths_path = cached_dir / f"dataloader[{type}].json"
-	cached_similar_path = cached_dir / f"similar[{type}].json"
-
-	# load the duration table first, since this is independent from the loaded paths
-	if cached_durations_path.exists():
-		_durations_map[type] = json_read( cached_durations_path )
-	# load the similar paths table as well, since this is also independent
-	if cached_similar_path.exists():
-		_similar_map[type] = json_read( cached_similar_path )
-
-	# load the cached valid paths (if we're requesting cache use)
-	if cached_paths_path.exists() and cfg.dataset.cache:
-		# to-do: automatic conversion between HDF5 formatted paths and on-disk paths
-		return json_read( cached_paths_path )
+	cached_path = cached_dir / f"dataset[{type}].json"
 
-	# deduce valid paths
-	paths = { cfg.get_spkr( cfg.data_dir / data_dir / "dummy" ): _load_paths_from_metadata( data_dir, type=type, validate=cfg.dataset.validate and type == "training" ) for data_dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent) }
+	if cached_path.exists() and cfg.dataset.cache:
+		return json_read( cached_path )
+
+	dataset_metadata = {}
+	def validate_utterance( id, entry ):
+		duration = entry.get('duration', 0)
+		in_bounds = cfg.dataset.min_duration <= duration <= cfg.dataset.max_duration
+
+		if cfg.dataset.validate and type == "training" and not in_bounds:
+			return False
+
+		if cfg.dataset.strict_validate:
+			if cfg.dataset.use_hdf5 and _get_path_key(type, dir, id) not in cfg.hdf5:
+				return False
+
+			if not (cfg.data_dir / dir / id).with_suffix(_get_artifact_extension()).exists():
+				return False
+
+		return True
+
+	def process_utterance( id, entry, metadata_keys=None ):
+		duration = entry.get('duration', 0)
+		similar = entry.get('similar', None)
+		# store the duration, plus similar-utterance key names (names, because indices might shift)
+		return [duration, ([ metadata_keys[idx] for idx in similar ] if similar and metadata_keys else [])]
+
+	for dir in tqdm(dataset, desc=f"Parsing dataset: {type}", disable=silent ):
+		metadata_path = cfg.metadata_dir / f'{dir}.json'
+		if not metadata_path.exists():
+			continue
+
+		# to-do: make json_read handle when it actually can't read the file......
+		try:
+			metadata = json_read( metadata_path )
+		except Exception as e:
+			continue
+
+		speaker = str(dir)
+		metadata_keys = list(metadata.keys())
+		dataset_metadata[speaker] = { id: process_utterance( id, entry, metadata_keys ) for id, entry in metadata.items() if validate_utterance( id, entry ) }
+
+		# remap similar-utterance key names back to indices within the validated set
+		remapped_indices = { k: i for i, k in enumerate(dataset_metadata[speaker].keys()) }
+		for id, (duration, similars) in dataset_metadata[speaker].items():
+			dataset_metadata[speaker][id][1] = [ remapped_indices[k] for k in similars if k in remapped_indices ]
 
 	# and write if global leader (to avoid other processes writing to the same file at once)
 	if is_global_leader():
 		if not cached_dir.exists():
 			cached_dir.mkdir(parents=True, exist_ok=True)
 
-		json_write( _similar_map[type], cached_similar_path, truncate=True )
-		json_write( _durations_map[type], cached_durations_path, truncate=True )
-		json_write( paths, cached_paths_path, truncate=True )
+		json_write( dataset_metadata, cached_path, truncate=True )
 
-	return paths
+	return dataset_metadata
 
-def _load_paths_from_metadata(group_name, type="training", validate=False):
-	data_dir = group_name if cfg.dataset.use_hdf5 else cfg.data_dir / group_name
-
-	_fn = _get_hdf5_paths if cfg.dataset.use_hdf5 else _get_paths_of_extensions
-
-	def key( id, entry=None ):
-		return f"/{type}/{_get_hdf5_path(data_dir)}/{id}" if cfg.dataset.use_hdf5 else str(data_dir / id)
-
-	metadata_path = cfg.metadata_dir / f'{group_name}.json'
-	metadata = {}
-
-	if cfg.dataset.use_metadata and metadata_path.exists():
-		try:
-			metadata = json_read( metadata_path )
-		except Exception as e:
-			return {}
-
-	if len(metadata) == 0:
-		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
-
-	# this might be slow
-	def _exists( id, entry ):
-		if not cfg.dataset.strict_validate:
-			return True
-
-		if cfg.dataset.use_hdf5:
-			return key(id, entry) in cfg.hdf5
-
-		return (data_dir / id).with_suffix(_get_artifact_extension()).exists()
-
-	metadata_keys = list(metadata.keys())
-	def _validate( id, entry ):
-		phones = entry.get('phones', 0)
-		duration = entry.get('duration', 0)
-		similar = entry.get('similar', None)
-
-		k = key(id, entry)
-
-		# add to duration bucket
-		if type not in _durations_map:
-			_durations_map[type] = {}
-		_durations_map[type][k] = duration
-
-		# add to similar bucket
-		if type not in _similar_map:
-			_similar_map[type] = {}
-		_similar_map[type][k] = [ metadata_keys[idx] for idx in similar ] if similar else None
-
-		if not validate:
-			return True
-
-		in_bounds = cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
-		if in_bounds and not _exists( id, entry ):
-			return False
-
-		return in_bounds
-
-	return [ key(id, entry) for id, entry in metadata.items() if _validate(id, entry) ]
-
-
 def _get_hdf5_path(path):
 	# to-do: better validation
 	return str(path)
 
-def _get_hdf5_paths( data_dir, type="training", validate=False ):
-	data_dir = str(data_dir)
-
-	key = f"/{type}/{_get_hdf5_path(data_dir)}"
-
-	def _validate( id, entry ):
-		phones = entry.attrs['phonemes']
-		duration = entry.attrs['duration']
-
-		if type not in _durations_map:
-			_durations_map[type] = {}
-		_durations_map[type][f"{key}/{id}"] = float(duration)
-
-		if not validate:
-			return True
-
-		return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration
-
-	return [ f"{key}/{id}" for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
 
 def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
 	if isinstance(path, str):
@@ -878,7 +817,7 @@ class Dataset(_Dataset):
 		self,
 		phone_symmap=None,
 		training=False,
-		extra_paths_by_spkr_name: dict[str, list] = {},
+		extra_paths_by_speaker_name: dict[str, list] = {},
 	):
 		super().__init__()
@@ -910,40 +849,31 @@ class Dataset(_Dataset):
 		if self.sampler_order == "duration" and self.sampler_type != "path":
 			raise Exception(f'Requesting sample_type={self.sampler_type} with sample_order={self.sampler_order}, yet combination will not give expected results.')
 
-		# dict of paths keyed by speaker names
-		self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
-		self.duration_map = _get_duration_map( self.dataset_type )
-
-		# cull speakers if they do not have enough utterances (or cull speakers with too many utternaces)
-		if cfg.dataset.min_utterances > 0 or cfg.dataset.max_utterances > 0:
-			keys = list(self.paths_by_spkr_name.keys())
-			for key in keys:
-				if len(self.paths_by_spkr_name[key]) < cfg.dataset.min_utterances:
-					del self.paths_by_spkr_name[key]
-					continue
-
-				# slice away extraneous utterances
-				if cfg.dataset.max_utterances:
-					self.paths_by_spkr_name[key] = self.paths_by_spkr_name[key][:cfg.dataset.max_utterances]
-
-		# flatten paths
-		self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
+		# dict that maps [speaker][id] to (duration, similar utterances)
+		self.metadata = _load_dataset_metadata(self.dataset, self.dataset_type, dataset_hash_key=self.dataset_hash_key)
+
+		# cull speakers with too few utterances
+		for speaker in list(self.metadata.keys()):
+			utterances = len(self.metadata[speaker])
+			if utterances < cfg.dataset.min_utterances:
+				del self.metadata[speaker]
+
+		self.paths = []
+		self.speakers = list(self.metadata.keys())
+		for speaker_id, speaker in enumerate(self.speakers):
+			utterances = len(self.metadata[speaker])
+			self.paths += [ (speaker_id, utterance_id) for utterance_id in range( utterances ) ]
+
 		# split dataset accordingly per GPU
 		if cfg.distributed and self.training:
 			self.paths = [ path for i, path in enumerate(self.paths) if i % world_size() == 0 ]
 
-		# recreate paths_by_spkr_name
-		self.paths_by_spkr_name = {}
-		for path in self.paths:
-			name = cfg.get_spkr( Path(path) )
-			if name not in self.paths_by_spkr_name:
-				self.paths_by_spkr_name[name] = []
-			self.paths_by_spkr_name[name].append( path )
-
 		# store in corresponding bucket
-		for path in self.paths:
-			duration = self.duration_map[path]
+		for (speaker_id, utterance_id) in self.paths:
+			speaker = self.speakers[speaker_id]
+			utterance = list(self.metadata[speaker].keys())[utterance_id]
+
+			duration, _ = self.metadata[speaker][utterance]
 			self.duration += duration
 
 			# only calc duration if we're going to order by duration
@@ -953,7 +883,7 @@ class Dataset(_Dataset):
 			bucket = int(round(duration))
 			if bucket not in self.duration_buckets:
 				self.duration_buckets[bucket] = []
-			self.duration_buckets[bucket].append( ( Path(path), duration ) )
+			self.duration_buckets[bucket].append( ( (speaker_id, utterance_id), duration ) )
 
 		# sort by duration
 		if self.sampler_order == "duration":
@@ -964,43 +894,28 @@ class Dataset(_Dataset):
 			# sort and interleave
 			for bucket in self.duration_buckets:
 				# sort by duration
-				self.duration_buckets[bucket].sort( key=lambda x: x[1] )
+				self.duration_buckets[bucket].sort( key=lambda x: x[-1] )
 				# split to retain tuples
 				flattened[bucket] = self.duration_buckets[bucket]
 				# replace with path
 				flattened[bucket] = [ x[0] for x in flattened[bucket] ]
 				# flatten by paths
-				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], self.get_speaker)]
+				flattened[bucket] = [*_interleaved_reorder(flattened[bucket], lambda x: x[0])]
 
 			# flatten paths
 			self.paths = list(itertools.chain.from_iterable(flattened.values()))
 		elif self.sampler_order == "random":
 			random.shuffle( self.paths )
 		else:
 			# just interleave
-			self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
-
-		# dict of speakers keyed by speaker group
-		self.spkrs_by_spkr_group = {}
-		for data_dir in self.dataset:
-			spkr = cfg.get_spkr( data_dir / "dummy" )
-			spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
-
-			if spkr not in self.paths_by_spkr_name or len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
-				continue
-
-			if spkr_group not in self.spkrs_by_spkr_group:
-				self.spkrs_by_spkr_group[spkr_group] = []
-
-			self.spkrs_by_spkr_group[spkr_group].append( spkr )
-
-		self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
+			self.paths = [*_interleaved_reorder(self.paths, lambda x: x[0])]
 
+		"""
 		self.noise_paths = _load_paths(cfg.dataset.noise, "noise")
 		self.noise_paths = list(itertools.chain.from_iterable(self.noise_paths.values()))
+		"""
 
 		self.phone_symmap = phone_symmap or self._get_phone_symmap()
-		self.spkr_symmap = self._get_spkr_symmap()
-		self.spkr_group_symmap = self._get_spkr_group_symmap()
+		self.speaker_symmap = self._get_speaker_symmap()
 		self.lang_symmap = self._get_lang_symmap()
 		self.tone_symmap = self._get_tone_symmap()
 		self.task_symmap = self._get_task_symmap()
@@ -1022,27 +937,16 @@ class Dataset(_Dataset):
 		if len(self.paths) == 0:
 			raise ValueError(f"No valid path is found for {self.dataset_type}")
 
-		if self.sampler_type == "path" and self.training:
-			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
-				self.sampler = BatchedOrderedSampler(
-					self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
-					max_duration=cfg.dataset.sample_max_duration_batch,
-					max_batch_size=self.batch_size,
-					shuffle=self.sampler_shuffle,
-				)
-				self.batch_size = 1
-			else:
-				self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
-			self.samplers = {}
-			self.spkr_samplers = {}
+		if self.training and self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
+			self.sampler = BatchedOrderedSampler(
+				self.duration_buckets if not self.sampler_state_dict_path.exists() else {}, # pass nothing if we're just going to load from a state anyways
+				max_duration=cfg.dataset.sample_max_duration_batch,
+				max_batch_size=self.batch_size,
+				shuffle=self.sampler_shuffle,
+			)
+			self.batch_size = 1
 		else:
-			self.sampler = RandomSampler( len(self) )
-			self.samplers = { name: PoolSampler( paths, keep_all=True, shuffle=self.sampler_shuffle ) for name, paths in self.paths_by_spkr_name.items() }
-			self.spkr_samplers = { name: PoolSampler( [*set(speakers)], keep_all=True, shuffle=self.sampler_shuffle ) for name, speakers in self.spkrs_by_spkr_group.items() }
-
-		# dereference buckets
-		self.duration_map = None
-		self.duration_buckets = None
+			self.sampler = OrderedSampler( len(self) ) if not self.sampler_shuffle else RandomSampler( len(self) )
 
 		self.load_state_dict()
@@ -1053,13 +957,13 @@ class Dataset(_Dataset):
 	def get_speaker(self, path):
 		if isinstance(path, str):
 			path = Path(path)
-		res = cfg.get_spkr(path)
+		res = cfg.get_speaker(path)
 		return res
 
 	def get_speaker_group(self, path):
 		if isinstance(path, str):
 			path = Path(path)
-		res = cfg.get_spkr_group(path)
+		res = cfg.get_speaker_group(path)
 		return res
 
 	# this isn't really necessary since our data/metadata contains markers for languages, but this is still in in-case it's needed to force a language setting (for example, whisperX's lang isn't that accurate at times)
@@ -1071,10 +975,6 @@ class Dataset(_Dataset):
 
 		return lang.lower()
 
-	@cached_property
-	def spkrs(self):
-		return sorted({self.get_speaker(path) for path in self.paths})
-
 	@cached_property
 	def tasks(self):
 		if not self.training:
@@ -1088,13 +988,16 @@ class Dataset(_Dataset):
 		if not path.parent.exists():
 			path.parent.mkdir(parents=True, exist_ok=True)
 
+		state_dict = self.sampler.get_state()
+		"""
 		if self.sampler_type == "path":
 			state_dict = self.sampler.get_state()
 		else:
 			state_dict = {
 				"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
-				"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
+				"speaker_samplers": { name: sampler.get_state() for name, sampler in self.speaker_samplers.items() },
 			}
+		"""
 
 		if "dataset_hash_key" not in state_dict:
 			state_dict["dataset_hash_key"] = self.dataset_hash_key
@@ -1117,6 +1020,8 @@ class Dataset(_Dataset):
 			_logger.warning(f'Mismatched dataset hash key for {self.dataset_type} dataloader, ignoring loading of state dict.')
 			return
 
+		state_dict = self.sampler.set_state(state_dict)
+		"""
 		if self.sampler_type == "path":
 			state_dict = self.sampler.set_state(state_dict)
 		else:
@@ -1125,19 +1030,17 @@ class Dataset(_Dataset):
 					continue
 				self.samplers[name].set_state( sampler )
 
-			for name, sampler in state_dict["spkr_samplers"].items():
-				if name not in self.spkr_samplers:
+			for name, sampler in state_dict["speaker_samplers"].items():
+				if name not in self.speaker_samplers:
 					continue
-				self.spkr_samplers[name].set_state( sampler )
+				self.speaker_samplers[name].set_state( sampler )
+		"""
 
 	def _get_phone_symmap(self):
 		return get_phone_symmap()
 
-	def _get_spkr_symmap(self):
-		return {s: i for i, s in enumerate(self.spkrs)}
-
-	def _get_spkr_group_symmap(self):
-		return {s: i for i, s in enumerate(self.spkr_groups)}
+	def _get_speaker_symmap(self):
+		return {s: i for i, s in enumerate(self.speakers)}
 
 	def _get_lang_symmap(self):
 		return get_lang_symmap()
@@ -1159,17 +1062,19 @@ class Dataset(_Dataset):
 		return qnt
 
 	def sample_speakers(self, ignore=[]):
-		choices = set(self.spkrs) - set(ignore)
+		choices = set(self.speakers) - set(ignore)
 		return random.choice([*choices])
 
-	def sample_utterance(self, spkr_name, ignore=[]):
-		choices = [*(set(self.paths_by_spkr_name[spkr_name]) - set(ignore))]
+	def sample_utterance(self, speaker_name, ignore=[]):
+		choices = [*(set(self.metadata[speaker_name].keys()) - set(ignore))]
 
 		if len(choices) == 0:
 			return None, None, None
 
-		path = random.choice(choices)
+		utterance_name = random.choice(choices)
+		path = cfg.data_dir / speaker_name / utterance_name
 
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
@@ -1201,13 +1106,11 @@ class Dataset(_Dataset):
 		return path, text, resps
 
 	# icky slop
-	def get_similar_utterance(self, path, offset=None ):
+	def get_similar_utterance(self, speaker_name, utterance_name, offset=None ):
 		if offset is None:
 			offset = cfg.dataset.prompt_similar_top_k_offset
 
-		root = Path( *path.parts[:-1] )
-		reference = path.name
-		similars = _similar_map[self.dataset_type].get(str(path), None)
+		_, similars = self.metadata[speaker_name][utterance_name]
 
 		if not similars:
 			return None
@@ -1223,41 +1126,28 @@ class Dataset(_Dataset):
 		if offset_end >= len( similars ):
 			return None
 
+		utterance_keys = list(self.metadata[speaker_name].keys())
 		if cfg.dataset.prompt_similar_top_k > 1:
-			name = random.choice( similars[offset:offset_end] )
+			index = random.choice( similars[offset:offset_end] )
 		else:
-			name = similars[offset]
+			index = similars[offset]
 
-		path = root / name
+		return utterance_keys[index]
 
-		if cfg.dataset.use_hdf5:
-			key = _get_hdf5_path(path)
-			if key not in cfg.hdf5[key]:
-				return None
-		elif not path.exists():
-			return None
-
-		return path
-
-	def sample_prompts(self, spkr_name, reference, should_trim=True):
+	def sample_prompts(self, speaker_name, utterance_name=None, should_trim=True):
 		# return no prompt if explicitly requested for who knows why
 		# or if there's no other speakers to sample from (Emilia has a lot of singleton speakers, but I still want to make use of them)
-		if len(self.paths_by_spkr_name[spkr_name]) <= 1:
+		if len(self.metadata[speaker_name]) <= 1:
 			return None
 
 		prom_list = []
 
-		choices = set(self.paths_by_spkr_name[spkr_name]) - {reference}
+		choices = set(self.metadata[speaker_name].keys()) - {utterance_name}
 		choices = [*choices]
 
 		# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
 		if len(choices) == 0:
-			choices = [*set(self.paths_by_spkr_name[spkr_name])]
-			"""
-			raise ValueError(
-				f"Failed to find another different utterance for {spkr_name}."
-			)
-			"""
+			choices = [*set(self.metadata[speaker_name].keys())]
 
 		if not cfg.dataset.prompt_duration_range or cfg.dataset.prompt_duration_range[1] <= 0:
 			should_trim = False
@@ -1271,17 +1161,18 @@ class Dataset(_Dataset):
 		for _ in range(cfg.dataset.prompt_max_samples):
 			# yuck
-			path = None
-			if reference is not None:
+			sampled_utterance = None
+			if utterance_name is not None:
 				if random.random() < cfg.dataset.prompt_similar_p:
 					try:
-						path = self.get_similar_utterance( reference, offset = len(prom_list) )
+						sampled_utterance = self.get_similar_utterance( speaker_name, utterance_name, offset = len(prom_list) )
 					except Exception as e:
-						path = None
+						sampled_utterance = None
 
-			if not path:
-				path = random.choice(choices)
+			if not sampled_utterance:
+				sampled_utterance = random.choice(choices)
 
+			path = cfg.data_dir / speaker_name / sampled_utterance
 			if cfg.dataset.use_hdf5:
 				key = _get_hdf5_path(path)
 				if key not in cfg.hdf5:
@@ -1294,6 +1185,7 @@ class Dataset(_Dataset):
 				except Exception as e:
 					_logger.warning(f'Failed to load artifact: {path} ({e})')
 					path = None
+					continue
 
 			if 0 < trim_length and trim_length < qnt.shape[0]:
 				qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
@@ -1321,27 +1213,10 @@ class Dataset(_Dataset):
 		bos_id, space_id, eos_id = self.empty_text
 
-		if self.sampler_type == "group":
-			spkr_group = self.spkr_groups[index]
-			#spkr_group_id = self.spkr_group_symmap[spkr_group]
-			spkr_name = self.spkr_samplers[spkr_group].sample()
-			spkr_id = self.spkr_symmap[spkr_name]
-			path = self.samplers[spkr_name].sample()
-		elif self.sampler_type == "speaker":
-			spkr_name = self.spkrs[index]
-			spkr_id = self.spkr_symmap[spkr_name]
-			path = self.samplers[spkr_name].sample()
-			spkr_group = self.get_speaker_group(path)
-			#spkr_group_id = self.spkr_group_symmap[spkr_group]
-		else:
-			path = self.paths[index]
-			spkr_name = self.get_speaker(path)
-			spkr_id = self.spkr_symmap[spkr_name]
-			spkr_group = self.get_speaker_group(path)
-			#spkr_group_id = self.spkr_group_symmap[spkr_group]
-
-		if not isinstance( path, Path ):
-			path = Path( path )
+		speaker_id, utterance_id = self.paths[index]
+		speaker_name = self.speakers[speaker_id]
+		utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
+		path = cfg.data_dir / speaker_name / utterance_name
 
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
@@ -1378,7 +1253,7 @@ class Dataset(_Dataset):
 			tone = metadata["tone"] if "tone" in metadata else None
 			text_string = metadata["text"] if "text" in metadata else None
 
-		lang = self.get_language(spkr_group) if not lang else lang.lower()
+		lang = lang.lower() if lang else "en"
 
 		raw_text = torch.tensor(text_tokenize(text_string)).to(torch.int16) if text_string else None
 
@@ -1398,7 +1273,7 @@ class Dataset(_Dataset):
 		if cfg.dataset.resps_max_samples > 1 and random.random() < cfg.dataset.resps_append_p:
 			ignore_paths = []
 			for _ in range( 1, cfg.dataset.resps_max_samples ):
-				path, txt, qnt = self.sample_utterance(spkr_name, ignore=ignore_paths)
+				path, txt, qnt = self.sample_utterance(speaker_name, ignore=ignore_paths)
 				ignore_paths.append(path)
 
 				# [original text][new text]
@@ -1412,7 +1287,7 @@ class Dataset(_Dataset):
 				# might be better to decode => concat waveforms with silence in between => reencode
 				# as you technically can't just append encodec sequences together like this without issues
 				resps = concat_audio( resps, qnt, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
-		
+
 		task = random.choice(self.tasks)
 
 		if task not in self.task_symmap:
@@ -1420,8 +1295,10 @@ class Dataset(_Dataset):
 		# Base TTS (<text><prompt> => <resp>)
 		if task == "tts":
-			proms = self.sample_prompts(spkr_name, reference=path)
+			proms = self.sample_prompts(speaker_name, utterance_name)
 
+			"""
+			proms = self.sample_prompts(speaker_name, reference=path)
 			if random.random() < cfg.dataset.prompt_inject_noise_p:
 				# sample random noise
 				noise = self.sample_noise()
@@ -1429,7 +1306,7 @@ class Dataset(_Dataset):
 				noise = repeat_extend_audio(noise, proms.shape[0])
 				# create the input prompt by merging the target audio with the noise
 				proms = merge_audio( proms, noise, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
-
+			"""
 		# VALL-E Continuous (<text><partial resp> => <remaining resp>)
 		# (this could just be sampled as