fix loading without needing an hdf5 dataset already prepped (and some other incidental speedups during dataloader prep)
commit 4bc7e5a6d1
parent d88a5ca183
@@ -63,7 +63,7 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_quant_extension():
-	return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt"
+	return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
 
 def _get_phone_extension():
 	return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
@@ -161,25 +161,28 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
 
	return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
 
-def _load_quants(path) -> Tensor:
-	if _get_quant_extension() == ".dac":
-		qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
-	return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
+def _load_quants(path, return_metadata=False) -> Tensor:
+	qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+	if return_metadata:
+		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
+	return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
 
 # prune consecutive spaces
 def _cleanup_phones( phones, targets=[" "]):
 	return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
 
 @cache
-def _get_phones(path, language="en"):
-	if _get_quant_extension() == ".json":
-		metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
-		content = metadata["phonemes"]
-	else:
-		content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
+def _get_phones(path):
+	phone_path = _get_phone_path(path)
+	quant_path = _get_quant_path(path)
+	if phone_path.exists():
+		metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
+	elif quant_path.exists():
+		_, metadata = _load_quants( path, return_metadata=True )
+	else:
+		raise Exception(f"Could not load phonemes: {path}")
 
+	content = metadata["phonemes"]
	return "".join(content)
 
 def _interleaved_reorder(l, fn):
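In practice the reworked helpers mean an utterance can be loaded straight from its quantized file, with the phonemes riding along in the file's embedded metadata, so neither a prepped HDF5 dataset nor a .json sidecar is strictly required. A rough usage sketch (the call sites below are illustrative, not lifted from the diff):

	qnt = _load_quants(path)                                  # codes only, same as before
	qnt, metadata = _load_quants(path, return_metadata=True)  # codes plus the metadata dict stored in the file
	phones = _get_phones(path)  # prefers the .json sidecar, falls back to the quant file's metadata, else raises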
@@ -269,9 +272,11 @@ class Dataset(_Dataset):
		#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
		self.duration = _calculate_durations(self.dataset_type)
 
+	"""
	@cached_property
	def phones(self):
		return sorted(set().union(*[_get_phones(path) for path in self.paths]))
+	"""
 
	def get_speaker(self, path):
		if isinstance(path, str):
@@ -350,7 +355,7 @@ class Dataset(_Dataset):
			key = _get_hdf5_path(path)
			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
		else:
-			qnt = _load_quants(path)
+			qnt = _load_quants(path, return_metadata=False)
		return qnt
 
	def sample_speakers(self, ignore=[]):
@@ -386,7 +391,7 @@ class Dataset(_Dataset):
 
			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
		else:
-			qnt = _load_quants(path)
+			qnt = _load_quants(path, return_metadata=False)
 
		if 0 < trim_length and trim_length < qnt.shape[0]:
			qnt = trim( qnt, trim_length )
@@ -438,8 +443,9 @@ class Dataset(_Dataset):
			text = torch.from_numpy(text).to(self.text_dtype)
			resps = torch.from_numpy(resps).to(torch.int16)
		else:
-			text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
-			resps = _load_quants(path)
+			resps, metadata = _load_quants(path, return_metadata=True)
+			text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
+			#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
 
		lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
 
@@ -462,8 +468,9 @@ class Dataset(_Dataset):
			qnt = torch.from_numpy(qnt).to(torch.int16)
		else:
			#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
-			txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
-			qnt = _load_quants(sampled_path)
+			#txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
+			qnt, metadata = _load_quants(sampled_path, return_metadata=True)
+			txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 
		# <s>[original text] [new text]</s>
		# removes the original text's </s>, includes a space, and remove the new text's <s>
@@ -788,10 +795,10 @@ def create_dataset_metadata( skip_existing=True ):
 
		for id in tqdm(ids, desc=f"Processing {name}"):
			try:
-				audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
 
-				if not audio_exists or not text_exists:
+				if not quant_exists:
					continue
 
				key = f'{type}/{speaker_name}/{id}'
@@ -817,8 +824,7 @@ def create_dataset_metadata( skip_existing=True ):
					if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
						utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
				# text
-				if texts:
-					if not utterance_metadata:
-						utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+				if texts and text_exists and not utterance_metadata:
+					utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
 
				for k, v in utterance_metadata.items():
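As an aside, the duration bookkeeping above is plain samples-over-rate arithmetic, so no audio has to be decoded to get it; for example (made-up numbers):

	# 72000 samples of source audio at a 24000 Hz sample rate
	duration = 72000 / 24000  # 3.0 seconds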
@@ -878,27 +884,19 @@ def create_dataset_hdf5( skip_existing=True ):
 
		for id in tqdm(ids, desc=f"Processing {name}"):
			try:
-				audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}')
-				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if type != "Noise" else True
+				quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+				text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
 
-				if not audio_exists:
+				if not quant_exists:
					continue
 
				key = f'{type}/{speaker_name}/{id}'
 
-				"""
				if skip_existing and key in hf:
					continue
-				"""
 
				group = hf.create_group(key) if key not in hf else hf[key]
 
-				"""
-				group.attrs['id'] = id
-				group.attrs['type'] = type
-				group.attrs['speaker'] = speaker_name
-				"""
-
				if id not in metadata:
					metadata[id] = {}
 
@@ -906,7 +904,6 @@ def create_dataset_hdf5( skip_existing=True ):
 
				# audio
				if audios:
-					# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
					dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
					qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
 
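For context, every np.load(..., allow_pickle=True)[()] in this commit unpickles a per-utterance dict. A minimal sketch of the payload the loader assumes — field names are taken from the diff, while the shapes, example values, and filename are assumptions:

	import numpy as np

	payload = {
		# integer codebook indices; indexing [0] then .t() in the diff implies shape (1, n_levels, n_frames)
		"codes": np.zeros((1, 8, 100), dtype=np.int64),
		"metadata": {
			"phonemes": ["h", "e", "l", "o"],  # consumed when no .json sidecar exists
			"original_length": 72000,          # source audio length in samples (assumed)
			"sample_rate": 24000,              # source audio sample rate (assumed)
		},
	}
	with open("utterance.dac", "wb") as f:  # hypothetical path; np.save on a file object keeps the extension
		np.save(f, payload)
	# np.load("utterance.dac", allow_pickle=True)[()] round-trips the dict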