diff --git a/vall_e/data.py b/vall_e/data.py
index bbbe16b..b2c12f3 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -63,7 +63,7 @@ def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_quant_extension():
- return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
+ return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
def _get_phone_extension():
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
@@ -161,25 +161,28 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
-def _load_quants(path) -> Tensor:
- if _get_quant_extension() == ".dac":
- qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
- return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
-
- return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
+def _load_quants(path, return_metadata=False) -> Tensor:
+ qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+ if return_metadata:
+ return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
+ return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
# prune consecutive spaces
def _cleanup_phones( phones, targets=[" "]):
return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
@cache
-def _get_phones(path, language="en"):
- if _get_quant_extension() == ".json":
- metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
- content = metadata["phonemes"]
+def _get_phones(path):
+ phone_path = _get_phone_path(path)
+ quant_path = _get_quant_path(path)
+ if phone_path.exists():
+ metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
+ elif quant_path.exists():
+ _, metadata = _load_quants( path, return_metadata=True )
else:
- content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
+ raise Exception(f"Could not load phonemes: {path}")
+ content = metadata["phonemes"]
return "".join(content)
def _interleaved_reorder(l, fn):
@@ -269,9 +272,11 @@ class Dataset(_Dataset):
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self.duration = _calculate_durations(self.dataset_type)
+ """
@cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
+ """
def get_speaker(self, path):
if isinstance(path, str):
@@ -350,7 +355,7 @@ class Dataset(_Dataset):
key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
- qnt = _load_quants(path)
+ qnt = _load_quants(path, return_metadata=False)
return qnt
def sample_speakers(self, ignore=[]):
@@ -386,7 +391,7 @@ class Dataset(_Dataset):
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
- qnt = _load_quants(path)
+ qnt = _load_quants(path, return_metadata=False)
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length )
@@ -438,8 +443,9 @@ class Dataset(_Dataset):
text = torch.from_numpy(text).to(self.text_dtype)
resps = torch.from_numpy(resps).to(torch.int16)
else:
- text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
- resps = _load_quants(path)
+ resps, metadata = _load_quants(path, return_metadata=True)
+ text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
+ #text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
@@ -462,8 +468,9 @@ class Dataset(_Dataset):
qnt = torch.from_numpy(qnt).to(torch.int16)
else:
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
- txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
- qnt = _load_quants(sampled_path)
+ #txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
+ qnt, metadata = _load_quants(sampled_path, return_metadata=True)
+ txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
# [original text] [new text]
# removes the original text's , includes a space, and remove the new text's
@@ -788,10 +795,10 @@ def create_dataset_metadata( skip_existing=True ):
for id in tqdm(ids, desc=f"Processing {name}"):
try:
- audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+ quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
- if not audio_exists or not text_exists:
+ if not quant_exists:
continue
key = f'{type}/{speaker_name}/{id}'
@@ -817,9 +824,8 @@ def create_dataset_metadata( skip_existing=True ):
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
# text
- if texts:
- if not utterance_metadata:
- utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+ if texts and text_exists and not utterance_metadata:
+ utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
for k, v in utterance_metadata.items():
metadata[id][k] = v
@@ -878,27 +884,19 @@ def create_dataset_hdf5( skip_existing=True ):
for id in tqdm(ids, desc=f"Processing {name}"):
try:
- audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}')
- text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if type != "Noise" else True
+ quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+ text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
- if not audio_exists:
+ if not quant_exists:
continue
key = f'{type}/{speaker_name}/{id}'
- """
if skip_existing and key in hf:
continue
- """
group = hf.create_group(key) if key not in hf else hf[key]
- """
- group.attrs['id'] = id
- group.attrs['type'] = type
- group.attrs['speaker'] = speaker_name
- """
-
if id not in metadata:
metadata[id] = {}
@@ -906,7 +904,6 @@ def create_dataset_hdf5( skip_existing=True ):
# audio
if audios:
- # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)