From 47eb49804651f63fa579953b5840e78031b98ff6 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 6 Feb 2025 23:26:26 -0600 Subject: [PATCH] more tweaks --- vall_e/data.py | 33 +++++++++++++++++++++------------ vall_e/models/ar_nar.py | 13 +++++++------ 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index c3584f4..e7b58fe 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -693,7 +693,8 @@ def _replace_file_extension(path, suffix): return (path.parent / path.name.split(".")[0]).with_suffix(suffix) def _get_artifact_extension(): - return ".dac" if cfg.audio_backend == "dac" else ".enc" + #return ".dac" if cfg.audio_backend == "dac" else ".enc" + return cfg.audio_backend_extension def _get_metadata_extension(): return ".json" @@ -805,11 +806,23 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else [] -def _load_artifact(path, return_metadata=False) -> Tensor: - qnt = np.load(_get_artifact_path(path), allow_pickle=True)[()] +def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor: + artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()] + codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16) + # artifact was saved as a batch + if codes.dim() == 3: + codes = codes[0] + # (codebook, frame) => (frame, codebook) + if codes.shape[0] < codes.shape[1]: + codes = codes.t() + + if return_artifact: + return codes, artifact + if return_metadata: - return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"] - return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16) + return codes, artifact["metadata"] + + return codes def _interleaved_reorder(l, fn): groups = defaultdict(list) @@ -1102,7 +1115,7 @@ class Dataset(_Dataset): key = _get_hdf5_path(path) qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) else: - qnt = _load_artifact(path, return_metadata=False) + qnt = _load_artifact(path, return_metadata=False, return_artifact=False) return qnt def sample_speakers(self, ignore=[]): @@ -1758,9 +1771,7 @@ def create_dataset_metadata( skip_existing=False ): utterance_metadata = {} if audios: - artifact = np.load(quant_path, allow_pickle=True)[()] - qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16) - + qnt, artifact = _load_artifact(quant_path, return_artifact=True) utterance_metadata = process_artifact_metadata( artifact ) # to-do: derive duration from codes if duration is malformed because this happened to me with LibriTTS-R utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second @@ -1866,9 +1877,7 @@ def create_dataset_hdf5( skip_existing=True ): # audio if audios: - artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()] - qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16) - + qnt, artifact = _load_artifact(f'{root}/{name}/{id}{_get_artifact_extension()}', return_artifact=True) utterance_metadata = process_artifact_metadata( artifact ) if "audio" not in group: diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 4c8d9ec..960c566 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -958,6 +958,7 @@ def example_usage(): from tqdm import tqdm from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio + from ..data import _load_artifact from ..engines import Engine, Engines from ..utils import wrapper as ml from ..utils import setup_logging @@ -972,19 +973,19 @@ def example_usage(): setup_logging() def load_artifact( path ): - artifact = np.load(path, allow_pickle=True)[()] + audio, metadata = _load_artifact(path, return_metadata=True) - text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device) - audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device) + audio = audio.to(cfg.device) + text = torch.tensor( cfg.tokenizer.encode( metadata["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device) return text, audio - text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") + text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}") batch_size = cfg.hyperparameters.batch_size text_list = [ text ] * batch_size - proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size - resps_list = [ audio[:cfg.dataset.frames_per_second * 4, :] ] * batch_size + proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size + resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size kwargs = { 'n_text_tokens': 256,