diff --git a/vall_e/data.py b/vall_e/data.py
index bbbe16b..b2c12f3 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -63,7 +63,7 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_quant_extension():
-	return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
+	return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
 
 def _get_phone_extension():
 	return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
@@ -161,25 +161,28 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
 	return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
 
-def _load_quants(path) -> Tensor:
-	if _get_quant_extension() == ".dac":
-		qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
-
-	return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
+def _load_quants(path, return_metadata=False) -> Tensor:
+	qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+	if return_metadata:
+		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
+	return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
 
 # prune consecutive spaces
 def _cleanup_phones( phones, targets=[" "]):
 	return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
 
 @cache
-def _get_phones(path, language="en"):
-	if _get_quant_extension() == ".json":
-		metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
-		content = metadata["phonemes"]
+def _get_phones(path):
+	phone_path = _get_phone_path(path)
+	quant_path = _get_quant_path(path)
+	if phone_path.exists():
+		metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
+	elif quant_path.exists():
+		_, metadata = _load_quants( path, return_metadata=True )
 	else:
-		content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
+		raise Exception(f"Could not load phonemes: {path}")
+	content = metadata["phonemes"]
 	return "".join(content)
 
 def _interleaved_reorder(l, fn):
@@ -269,9 +272,11 @@ class Dataset(_Dataset):
 		#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
 		self.duration = _calculate_durations(self.dataset_type)
 
+	"""
 	@cached_property
 	def phones(self):
 		return sorted(set().union(*[_get_phones(path) for path in self.paths]))
+	"""
 
 	def get_speaker(self, path):
 		if isinstance(path, str):
@@ -350,7 +355,7 @@ class Dataset(_Dataset):
 			key = _get_hdf5_path(path)
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
-			qnt = _load_quants(path)
+			qnt = _load_quants(path, return_metadata=False)
 		return qnt
 
 	def sample_speakers(self, ignore=[]):
@@ -386,7 +391,7 @@ class Dataset(_Dataset):
 				qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 			else:
-				qnt = _load_quants(path)
+				qnt = _load_quants(path, return_metadata=False)
 
 			if 0 < trim_length and trim_length < qnt.shape[0]:
 				qnt = trim( qnt, trim_length )
@@ -438,8 +443,9 @@ class Dataset(_Dataset):
 			text = torch.from_numpy(text).to(self.text_dtype)
 			resps = torch.from_numpy(resps).to(torch.int16)
 		else:
-			text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
-			resps = _load_quants(path)
+			resps, metadata = _load_quants(path, return_metadata=True)
+			text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
+			#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
 
 		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
@@ -462,8 +468,9 @@ class Dataset(_Dataset):
 				qnt = torch.from_numpy(qnt).to(torch.int16)
 			else:
 				#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
-				txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
-				qnt = _load_quants(sampled_path)
+				#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
+				qnt, metadata = _load_quants(sampled_path, return_metadata=True)
+				txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 
 				# <s>[original text]</s> <s>[new text]</s>
 				# removes the original text's </s>, includes a space, and remove the new text's <s>
@@ -788,10 +795,10 @@ def create_dataset_metadata( skip_existing=True ):
 		for id in tqdm(ids, desc=f"Processing {name}"):
 			try:
-				audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
 				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
 
-				if not audio_exists or not text_exists:
+				if not quant_exists:
 					continue
 
 				key = f'{type}/{speaker_name}/{id}'
@@ -817,9 +824,8 @@ def create_dataset_metadata( skip_existing=True ):
 					if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
 						utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
 				# text
-				if texts:
-					if not utterance_metadata:
-						utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+				if texts and text_exists and not utterance_metadata:
+					utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
 
 				for k, v in utterance_metadata.items():
 					metadata[id][k] = v
@@ -878,27 +884,19 @@ def create_dataset_hdf5( skip_existing=True ):
 		for id in tqdm(ids, desc=f"Processing {name}"):
 			try:
-				audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}')
-				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if type != "Noise" else True
+				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
 
-				if not audio_exists:
+				if not quant_exists:
 					continue
 
 				key = f'{type}/{speaker_name}/{id}'
 
-				"""
 				if skip_existing and key in hf:
 					continue
-				"""
 
 				group = hf.create_group(key) if key not in hf else hf[key]
 
-				"""
-				group.attrs['id'] = id
-				group.attrs['type'] = type
-				group.attrs['speaker'] = speaker_name
-				"""
-
 				if id not in metadata:
 					metadata[id] = {}
 
@@ -906,7 +904,6 @@ def create_dataset_hdf5( skip_existing=True ):
 
 				# audio
 				if audios:
-					# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
 					dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
 					qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
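
Note on the new on-disk format (not part of the patch itself): after this change, _load_quants() always goes through np.load(..., allow_pickle=True) and expects a numpy-pickled dict carrying both the codes and the utterance metadata, so phonemes can be recovered without a sidecar .json. The sketch below is a minimal illustration of that layout and of the loader's round-trip; the file name, code shapes, and metadata values are assumptions for demonstration, and only the dict keys ("codes", "metadata", "phonemes") come from the diff above.

    import numpy as np
    import torch

    # Write a quantized utterance the way the patched loader reads it:
    # a pickled dict saved with np.save (readable back via allow_pickle=True).
    payload = {
        "codes": np.random.randint(0, 1024, size=(1, 8, 75)),  # [batch, quant level, frame] -- assumed shape
        "metadata": {
            "phonemes": ["h", "ə", "l", "o", "ʊ"],             # hypothetical content
            "original_length": 72000,                          # used by create_dataset_metadata for duration
            "sample_rate": 24000,                              # 72000 / 24000 = 3.0 seconds
        },
    }
    with open("utterance.enc", "wb") as f:                     # ".enc" per the new _get_quant_extension()
        np.save(f, payload)

    # Mirror of the patched _load_quants(): [()] unwraps the pickled dict,
    # [0] drops the batch dim, .t() yields [frame, quant level] int16 codes.
    qnt = np.load("utterance.enc", allow_pickle=True)[()]
    codes = torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
    assert codes.shape == (75, 8)
    assert "".join(qnt["metadata"]["phonemes"]) == "həloʊ"     # what _get_phones() now returns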