fix loading without needing an hdf5 dataset already prepped (and some other incidental speedups during dataloader prep)

This commit is contained in:
mrq 2024-05-18 07:14:26 -05:00
parent d88a5ca183
commit 4bc7e5a6d1

View File

@ -63,7 +63,7 @@ def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_quant_extension():
return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
def _get_phone_extension():
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
@ -161,25 +161,28 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
def _load_quants(path) -> Tensor:
if _get_quant_extension() == ".dac":
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
def _load_quants(path, return_metadata=False) -> Tensor:
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
if return_metadata:
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
# prune consecutive spaces
def _cleanup_phones( phones, targets=[" "]):
return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
@cache
def _get_phones(path, language="en"):
if _get_quant_extension() == ".json":
metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
content = metadata["phonemes"]
def _get_phones(path):
phone_path = _get_phone_path(path)
quant_path = _get_quant_path(path)
if phone_path.exists():
metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
elif quant_path.exists():
_, metadata = _load_quants( path, return_metadata=True )
else:
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
raise Exception(f"Could not load phonemes: {path}")
content = metadata["phonemes"]
return "".join(content)
def _interleaved_reorder(l, fn):
@ -269,9 +272,11 @@ class Dataset(_Dataset):
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self.duration = _calculate_durations(self.dataset_type)
"""
@cached_property
def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths]))
"""
def get_speaker(self, path):
if isinstance(path, str):
@ -350,7 +355,7 @@ class Dataset(_Dataset):
key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
qnt = _load_quants(path, return_metadata=False)
return qnt
def sample_speakers(self, ignore=[]):
@ -386,7 +391,7 @@ class Dataset(_Dataset):
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else:
qnt = _load_quants(path)
qnt = _load_quants(path, return_metadata=False)
if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length )
@ -438,8 +443,9 @@ class Dataset(_Dataset):
text = torch.from_numpy(text).to(self.text_dtype)
resps = torch.from_numpy(resps).to(torch.int16)
else:
text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
resps = _load_quants(path)
resps, metadata = _load_quants(path, return_metadata=True)
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
@ -462,8 +468,9 @@ class Dataset(_Dataset):
qnt = torch.from_numpy(qnt).to(torch.int16)
else:
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
qnt = _load_quants(sampled_path)
#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
qnt, metadata = _load_quants(sampled_path, return_metadata=True)
txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
# <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s>
@ -788,10 +795,10 @@ def create_dataset_metadata( skip_existing=True ):
for id in tqdm(ids, desc=f"Processing {name}"):
try:
audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not audio_exists or not text_exists:
if not quant_exists:
continue
key = f'{type}/{speaker_name}/{id}'
@ -817,9 +824,8 @@ def create_dataset_metadata( skip_existing=True ):
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
# text
if texts:
if not utterance_metadata:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
if texts and text_exists and not utterance_metadata:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
for k, v in utterance_metadata.items():
metadata[id][k] = v
@ -878,27 +884,19 @@ def create_dataset_hdf5( skip_existing=True ):
for id in tqdm(ids, desc=f"Processing {name}"):
try:
audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}')
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if type != "Noise" else True
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not audio_exists:
if not quant_exists:
continue
key = f'{type}/{speaker_name}/{id}'
"""
if skip_existing and key in hf:
continue
"""
group = hf.create_group(key) if key not in hf else hf[key]
"""
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = speaker_name
"""
if id not in metadata:
metadata[id] = {}
@ -906,7 +904,6 @@ def create_dataset_hdf5( skip_existing=True ):
# audio
if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)