fix loading without needing an hdf5 dataset already prepped (and some other incidental speedups during dataloader prep)

This commit is contained in:
mrq 2024-05-18 07:14:26 -05:00
parent d88a5ca183
commit 4bc7e5a6d1

View File

@ -63,7 +63,7 @@ def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix) return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_quant_extension(): def _get_quant_extension():
return ".dac" if cfg.inference.audio_backend == "dac" else ".qnt.pt" return ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
def _get_phone_extension(): def _get_phone_extension():
return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt" return ".json" # if cfg.inference.audio_backend == "dac" else ".phn.txt"
@ -161,25 +161,28 @@ def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else [] return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
def _load_quants(path) -> Tensor: def _load_quants(path, return_metadata=False) -> Tensor:
if _get_quant_extension() == ".dac":
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()] qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
if return_metadata:
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16) return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
# prune consecutive spaces # prune consecutive spaces
def _cleanup_phones( phones, targets=[" "]): def _cleanup_phones( phones, targets=[" "]):
return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ] return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
@cache @cache
def _get_phones(path, language="en"): def _get_phones(path):
if _get_quant_extension() == ".json": phone_path = _get_phone_path(path)
metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read()) quant_path = _get_quant_path(path)
content = metadata["phonemes"] if phone_path.exists():
metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
elif quant_path.exists():
_, metadata = _load_quants( path, return_metadata=True )
else: else:
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ") raise Exception(f"Could not load phonemes: {path}")
content = metadata["phonemes"]
return "".join(content) return "".join(content)
def _interleaved_reorder(l, fn): def _interleaved_reorder(l, fn):
@ -269,9 +272,11 @@ class Dataset(_Dataset):
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0 #self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
self.duration = _calculate_durations(self.dataset_type) self.duration = _calculate_durations(self.dataset_type)
"""
@cached_property @cached_property
def phones(self): def phones(self):
return sorted(set().union(*[_get_phones(path) for path in self.paths])) return sorted(set().union(*[_get_phones(path) for path in self.paths]))
"""
def get_speaker(self, path): def get_speaker(self, path):
if isinstance(path, str): if isinstance(path, str):
@ -350,7 +355,7 @@ class Dataset(_Dataset):
key = _get_hdf5_path(path) key = _get_hdf5_path(path)
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else: else:
qnt = _load_quants(path) qnt = _load_quants(path, return_metadata=False)
return qnt return qnt
def sample_speakers(self, ignore=[]): def sample_speakers(self, ignore=[]):
@ -386,7 +391,7 @@ class Dataset(_Dataset):
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
else: else:
qnt = _load_quants(path) qnt = _load_quants(path, return_metadata=False)
if 0 < trim_length and trim_length < qnt.shape[0]: if 0 < trim_length and trim_length < qnt.shape[0]:
qnt = trim( qnt, trim_length ) qnt = trim( qnt, trim_length )
@ -438,8 +443,9 @@ class Dataset(_Dataset):
text = torch.from_numpy(text).to(self.text_dtype) text = torch.from_numpy(text).to(self.text_dtype)
resps = torch.from_numpy(resps).to(torch.int16) resps = torch.from_numpy(resps).to(torch.int16)
else: else:
text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype) resps, metadata = _load_quants(path, return_metadata=True)
resps = _load_quants(path) text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
#text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype)
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8) lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
@ -462,8 +468,9 @@ class Dataset(_Dataset):
qnt = torch.from_numpy(qnt).to(torch.int16) qnt = torch.from_numpy(qnt).to(torch.int16)
else: else:
#txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype) #txt = torch.tensor([*map(self.phone_symmap.get, _get_phones(sampled_path))]).to(self.text_dtype)
txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype) #txt = torch.tensor(tokenize(_get_phones(sampled_path))).to(self.text_dtype)
qnt = _load_quants(sampled_path) qnt, metadata = _load_quants(sampled_path, return_metadata=True)
txt = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
# <s>[original text] [new text]</s> # <s>[original text] [new text]</s>
# removes the original text's </s>, includes a space, and remove the new text's <s> # removes the original text's </s>, includes a space, and remove the new text's <s>
@ -788,10 +795,10 @@ def create_dataset_metadata( skip_existing=True ):
for id in tqdm(ids, desc=f"Processing {name}"): for id in tqdm(ids, desc=f"Processing {name}"):
try: try:
audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not audio_exists or not text_exists: if not quant_exists:
continue continue
key = f'{type}/{speaker_name}/{id}' key = f'{type}/{speaker_name}/{id}'
@ -817,8 +824,7 @@ def create_dataset_metadata( skip_existing=True ):
if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]: if "original_length" in dac["metadata"] and "sample_rate" in dac["metadata"]:
utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"] utterance_metadata["duration"] = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
# text # text
if texts: if texts and text_exists and not utterance_metadata:
if not utterance_metadata:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
for k, v in utterance_metadata.items(): for k, v in utterance_metadata.items():
@ -878,27 +884,19 @@ def create_dataset_hdf5( skip_existing=True ):
for id in tqdm(ids, desc=f"Processing {name}"): for id in tqdm(ids, desc=f"Processing {name}"):
try: try:
audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if type != "Noise" else True text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not audio_exists: if not quant_exists:
continue continue
key = f'{type}/{speaker_name}/{id}' key = f'{type}/{speaker_name}/{id}'
"""
if skip_existing and key in hf: if skip_existing and key in hf:
continue continue
"""
group = hf.create_group(key) if key not in hf else hf[key] group = hf.create_group(key) if key not in hf else hf[key]
"""
group.attrs['id'] = id
group.attrs['type'] = type
group.attrs['speaker'] = speaker_name
"""
if id not in metadata: if id not in metadata:
metadata[id] = {} metadata[id] = {}
@ -906,7 +904,6 @@ def create_dataset_hdf5( skip_existing=True ):
# audio # audio
if audios: if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16) qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)