more tweaks

mrq 2025-02-06 23:26:26 -06:00
parent 67a9401cce
commit 47eb498046
2 changed files with 28 additions and 18 deletions


@@ -693,7 +693,8 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

 def _get_artifact_extension():
-	return ".dac" if cfg.audio_backend == "dac" else ".enc"
+	#return ".dac" if cfg.audio_backend == "dac" else ".enc"
+	return cfg.audio_backend_extension

 def _get_metadata_extension():
 	return ".json"
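The hard-coded mapping above is replaced by a lookup on cfg.audio_backend_extension. As a rough illustration only (the Config class, field name, and mapping below are assumptions for this sketch, not the repo's actual implementation), such a property could look like:

    class Config:
        # assumed field; the real config object carries far more settings
        audio_backend: str = "encodec"

        @property
        def audio_backend_extension(self) -> str:
            # hypothetical backend -> on-disk artifact extension mapping
            return ".dac" if self.audio_backend == "dac" else ".enc"

    cfg = Config()
    print(cfg.audio_backend_extension)  # ".enc"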
@@ -805,11 +806,23 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
 	return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []

-def _load_artifact(path, return_metadata=False) -> Tensor:
-	qnt = np.load(_get_artifact_path(path), allow_pickle=True)[()]
+def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor:
+	artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
+	codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16)
+
+	# artifact was saved as a batch
+	if codes.dim() == 3:
+		codes = codes[0]
+
+	# (codebook, frame) => (frame, codebook)
+	if codes.shape[0] < codes.shape[1]:
+		codes = codes.t()
+
+	if return_artifact:
+		return codes, artifact

 	if return_metadata:
-		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
-	return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
+		return codes, artifact["metadata"]
+	return codes

 def _interleaved_reorder(l, fn):
 	groups = defaultdict(list)
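The reshaping that the new _load_artifact performs can be exercised in isolation. A minimal sketch, using a dummy batched (1, codebook, frame) array in place of a real .enc/.dac artifact:

    import numpy as np
    import torch

    # dummy stand-in for artifact["codes"]: 1 batch, 8 codebooks, 300 frames
    codes = torch.from_numpy(np.zeros((1, 8, 300)).astype(int)).to(torch.int16)

    # artifact was saved as a batch => drop the batch dimension
    if codes.dim() == 3:
        codes = codes[0]

    # (codebook, frame) => (frame, codebook)
    if codes.shape[0] < codes.shape[1]:
        codes = codes.t()

    print(codes.shape)  # torch.Size([300, 8])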
@@ -1102,7 +1115,7 @@ class Dataset(_Dataset):
 			key = _get_hdf5_path(path)
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
-			qnt = _load_artifact(path, return_metadata=False)
+			qnt = _load_artifact(path, return_metadata=False, return_artifact=False)
 		return qnt

 	def sample_speakers(self, ignore=[]):
@@ -1758,9 +1771,7 @@ def create_dataset_metadata( skip_existing=False ):
 			utterance_metadata = {}
 			if audios:
-				artifact = np.load(quant_path, allow_pickle=True)[()]
-				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
-
+				qnt, artifact = _load_artifact(quant_path, return_artifact=True)
 				utterance_metadata = process_artifact_metadata( artifact )
 				# to-do: derive duration from codes if duration is malformed because this happened to me with LibriTTS-R
 				utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second
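The to-do above mentions deriving the duration from the codes when the stored value is malformed (as happened with LibriTTS-R). One possible fallback, sketched here under the assumption that a missing, zero, or negative duration counts as malformed (derive_duration is a hypothetical helper, not part of the repo):

    def derive_duration(utterance_metadata: dict, qnt, frames_per_second: float) -> float:
        # trust the stored duration only if it is a sane positive number
        duration = utterance_metadata.get("duration", 0)
        if not duration or duration <= 0:
            # qnt is (frame, codebook), so its first dimension is the frame count
            duration = qnt.shape[0] / frames_per_second
        return duration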
@@ -1866,9 +1877,7 @@ def create_dataset_hdf5( skip_existing=True ):
 			# audio
 			if audios:
-				artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()]
-				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
-
+				qnt, artifact = _load_artifact(f'{root}/{name}/{id}{_get_artifact_extension()}', return_artifact=True)
 				utterance_metadata = process_artifact_metadata( artifact )

 				if "audio" not in group:


@@ -958,6 +958,7 @@ def example_usage():
 	from tqdm import tqdm
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
+	from ..data import _load_artifact
 	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
 	from ..utils import setup_logging
@@ -972,19 +973,19 @@ def example_usage():
 	setup_logging()

 	def load_artifact( path ):
-		artifact = np.load(path, allow_pickle=True)[()]
+		audio, metadata = _load_artifact(path, return_metadata=True)

-		text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device)
-		audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device)
+		audio = audio.to(cfg.device)
+		text = torch.tensor( cfg.tokenizer.encode( metadata["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device)

 		return text, audio

-	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+	text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}")
 	batch_size = cfg.hyperparameters.batch_size

 	text_list = [ text ] * batch_size
-	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
-	resps_list = [ audio[:cfg.dataset.frames_per_second * 4, :] ] * batch_size
+	proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size
+	resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size

 	kwargs = {
 		'n_text_tokens': 256,
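The int() casts on the prompt/response slices matter because tensor slicing requires integer bounds. A short sketch, assuming a fractional frames_per_second (the 86.1 below is a made-up value, not the repo's actual rate):

    import torch

    frames_per_second = 86.1  # hypothetical fractional frame rate
    audio = torch.zeros(1024, 8, dtype=torch.int16)  # dummy (frame, codebook) codes

    proms = audio[:int(frames_per_second), :]      # ~1 second prompt
    resps = audio[:int(frames_per_second * 4), :]  # ~4 second response
    print(proms.shape, resps.shape)  # torch.Size([86, 8]) torch.Size([344, 8])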