fix loading without needing an hdf5 dataset already prepped (and some other incidental speedups during dataloader prep)
This commit is contained in:
parent
d88a5ca183
commit
4bc7e5a6d1
|
@ -63,7 +63,7 @@ def _replace_file_extension(path, suffix):
|
|||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def _get_quant_extension():
|
||||
return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
|
||||
return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
||||
|
||||
def _get_phone_extension():
|
||||
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
|
||||
|
@ -161,25 +161,28 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
|
|||
|
||||
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
|
||||
|
||||
def _load_quants(path) -> Tensor:
|
||||
if _get_quant_extension() == ".dac":
|
||||
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
|
||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
|
||||
|
||||
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
|
||||
def _load_quants(path, return_metadata=False) -> Tensor:
|
||||
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
|
||||
if return_metadata:
|
||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
|
||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
|
||||
|
||||
# prune consecutive spaces
|
||||
def _cleanup_phones( phones, targets=[" "]):
|
||||
return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
|
||||
|
||||
@cache
|
||||
def _get_phones(path, language="en"):
|
||||
if _get_quant_extension() == ".json":
|
||||
metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
|
||||
content = metadata["phonemes"]
|
||||
def _get_phones(path):
|
||||
phone_path = _get_phone_path(path)
|
||||
quant_path = _get_quant_path(path)
|
||||
if phone_path.exists():
|
||||
metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
|
||||
elif quant_path.exists():
|
||||
_, metadata = _load_quants( path, return_metadata=True )
|
||||
else:
|
||||
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
|
||||
raise Exception(f"Could not load phonemes: {path}")
|
||||
|
||||
content = metadata["phonemes"]
|
||||
return "".join(content)
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
|
@ -269,9 +272,11 @@ class Dataset(_Dataset):
|
|||
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
|
||||
self.duration = _calculate_durations(self.dataset_type)
|
||||
|
||||
"""
|
||||
@cached_property
|
||||
def phones(self):
|
||||
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
|
||||
"""
|
||||
|
||||
def get_speaker(self, path):
|
||||
if isinstance(path, str):
|
||||
|
@ -350,7 +355,7 @@ class Dataset(_Dataset):
|
|||
key = _get_hdf5_path(path)
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
else:
|
||||
qnt = _load_quants(path)
|
||||
qnt = _load_quants(path, return_metadata=False)
|
||||
return qnt
|
||||
|
||||
def sample_speakers(self, ignore=[]):
|
||||
|
@ -386,7 +391,7 @@ class Dataset(_Dataset):
|
|||
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
else:
|
||||
qnt = _load_quants(path)
|
||||
qnt = _load_quants(path, return_metadata=False)
|
||||
|
||||
if 0 < trim_length and trim_length < qnt.shape[0]:
|
||||
qnt = trim( qnt, trim_length )
|
||||
|
@ -438,8 +443,9 @@ class Dataset(_Dataset):
|
|||
text = torch.from_numpy(text).to(self.text_dtype)
|
||||
resps = torch.from_numpy(resps).to(torch.int16)
|
||||
else:
|
||||
text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
|
||||
resps = _load_quants(path)
|
||||
resps, metadata = _load_quants(path, return_metadata=True)
|
||||
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||
#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
|
||||
|
||||
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
|
||||
|
||||
|
@ -462,8 +468,9 @@ class Dataset(_Dataset):
|
|||
qnt = torch.from_numpy(qnt).to(torch.int16)
|
||||
else:
|
||||
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
|
||||
txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
|
||||
qnt = _load_quants(sampled_path)
|
||||
#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
|
||||
qnt, metadata = _load_quants(sampled_path, return_metadata=True)
|
||||
txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||
|
||||
# <s>[original text] [new text]</s>
|
||||
# removes the original text's </s>, includes a space, and remove the new text's <s>
|
||||
|
@ -788,10 +795,10 @@ def create_dataset_metadata( skip_existing=True ):
|
|||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
try:
|
||||
audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
||||
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
|
||||
|
||||
if not audio_exists or not text_exists:
|
||||
if not quant_exists:
|
||||
continue
|
||||
|
||||
key = f'{type}/{speaker_name}/{id}'
|
||||
|
@ -817,9 +824,8 @@ def create_dataset_metadata( skip_existing=True ):
|
|||
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
|
||||
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
|
||||
# text
|
||||
if texts:
|
||||
if not utterance_metadata:
|
||||
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
||||
if texts and text_exists and not utterance_metadata:
|
||||
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
||||
|
||||
for k, v in utterance_metadata.items():
|
||||
metadata[id][k] = v
|
||||
|
@ -878,27 +884,19 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
for id in tqdm(ids, desc=f"Processing {name}"):
|
||||
try:
|
||||
audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}')
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if type != "Noise" else True
|
||||
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
|
||||
|
||||
if not audio_exists:
|
||||
if not quant_exists:
|
||||
continue
|
||||
|
||||
key = f'{type}/{speaker_name}/{id}'
|
||||
|
||||
"""
|
||||
if skip_existing and key in hf:
|
||||
continue
|
||||
"""
|
||||
|
||||
group = hf.create_group(key) if key not in hf else hf[key]
|
||||
|
||||
"""
|
||||
group.attrs['id'] = id
|
||||
group.attrs['type'] = type
|
||||
group.attrs['speaker'] = speaker_name
|
||||
"""
|
||||
|
||||
if id not in metadata:
|
||||
metadata[id] = {}
|
||||
|
||||
|
@ -906,7 +904,6 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
# audio
|
||||
if audios:
|
||||
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
|
||||
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
||||
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user