more tweaks
parent 67a9401cce
commit 47eb498046
@@ -693,7 +693,8 @@ def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_artifact_extension():
-	return ".dac" if cfg.audio_backend == "dac" else ".enc"
+	#return ".dac" if cfg.audio_backend == "dac" else ".enc"
+	return cfg.audio_backend_extension
 
 def _get_metadata_extension():
 	return ".json"
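The new return value assumes a config-level property that maps the selected audio backend to its file extension. A minimal sketch of what such a property could look like; the name audio_backend_extension comes from the diff itself, but the class and mapping below are illustrative assumptions, not the project's actual config code:

# Sketch only: assumed shape of the cfg.audio_backend_extension property.
class _CfgSketch:
	audio_backend = "dac"  # illustrative backend name

	@property
	def audio_backend_extension(self):
		# assumed mapping, mirroring the old ".dac" / ".enc" branch above
		return ".dac" if self.audio_backend == "dac" else ".enc"

_cfg_sketch = _CfgSketch()
assert _cfg_sketch.audio_backend_extension == ".dac"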
@@ -805,11 +806,23 @@ def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), valida
 
 	return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
 
-def _load_artifact(path, return_metadata=False) -> Tensor:
-	qnt = np.load(_get_artifact_path(path), allow_pickle=True)[()]
+def _load_artifact(path, return_metadata=False, return_artifact=False) -> Tensor:
+	artifact = np.load(_get_artifact_path(path), allow_pickle=True)[()]
+	codes = torch.from_numpy(artifact["codes"].astype(int)).to(torch.int16)
+	# artifact was saved as a batch
+	if codes.dim() == 3:
+		codes = codes[0]
+	# (codebook, frame) => (frame, codebook)
+	if codes.shape[0] < codes.shape[1]:
+		codes = codes.t()
+
+	if return_artifact:
+		return codes, artifact
+
 	if return_metadata:
-		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
-	return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
+		return codes, artifact["metadata"]
+
+	return codes
 
 def _interleaved_reorder(l, fn):
 	groups = defaultdict(list)
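For reference, a usage sketch of the consolidated loader's three return modes; the path string is a placeholder, and the artifact is assumed to be an np.save'd dict holding at least "codes" and "metadata", as read above:

# Sketch: callers of the reworked _load_artifact (the path is a placeholder).
codes = _load_artifact("some/utterance")                                  # (frame, codebook) int16 tensor
codes, metadata = _load_artifact("some/utterance", return_metadata=True)  # codes plus the stored metadata dict
codes, artifact = _load_artifact("some/utterance", return_artifact=True)  # codes plus the whole artifact dict
# Note: return_artifact is checked first, so it wins if both flags are set.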
@@ -1102,7 +1115,7 @@ class Dataset(_Dataset):
 			key = _get_hdf5_path(path)
 			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
-			qnt = _load_artifact(path, return_metadata=False)
+			qnt = _load_artifact(path, return_metadata=False, return_artifact=False)
 		return qnt
 
 	def sample_speakers(self, ignore=[]):
@@ -1758,9 +1771,7 @@ def create_dataset_metadata( skip_existing=False ):
 
 			utterance_metadata = {}
 			if audios:
-				artifact = np.load(quant_path, allow_pickle=True)[()]
-				qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
-
+				qnt, artifact = _load_artifact(quant_path, return_artifact=True)
 				utterance_metadata = process_artifact_metadata( artifact )
 				# to-do: derive duration from codes if duration is malformed because this happened to me with LibriTTS-R
 				utterance_metadata["duration"] = qnt.shape[0] / cfg.dataset.frames_per_second
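The duration line converts a frame count into seconds by dividing by the dataset frame rate. A quick worked example under an assumed 75 Hz code rate (illustrative, not taken from the project's config):

# Illustrative arithmetic for the duration computation above.
frames_per_second = 75   # assumed code frame rate
num_frames = 750         # qnt.shape[0] after the (frame, codebook) transpose
duration = num_frames / frames_per_second
assert duration == 10.0  # 750 frames at 75 fps is a 10-second utterance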
@@ -1866,9 +1877,7 @@ def create_dataset_hdf5( skip_existing=True ):
 
 				# audio
 				if audios:
-					artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()]
-					qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
-
+					qnt, artifact = _load_artifact(f'{root}/{name}/{id}{_get_artifact_extension()}', return_artifact=True)
 					utterance_metadata = process_artifact_metadata( artifact )
 
 					if "audio" not in group:
@@ -958,6 +958,7 @@ def example_usage():
 	from tqdm import tqdm
 
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
+	from ..data import _load_artifact
 	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
 	from ..utils import setup_logging
@@ -972,19 +973,19 @@ def example_usage():
 	setup_logging()
 
 	def load_artifact( path ):
-		artifact = np.load(path, allow_pickle=True)[()]
+		audio, metadata = _load_artifact(path, return_metadata=True)
 
-		text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device)
-		audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device)
+		audio = audio.to(cfg.device)
+		text = torch.tensor( cfg.tokenizer.encode( metadata["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device)
 
 		return text, audio
 
-	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+	text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}")
 	batch_size = cfg.hyperparameters.batch_size
 
 	text_list = [ text ] * batch_size
-	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
-	resps_list = [ audio[:cfg.dataset.frames_per_second * 4, :] ] * batch_size
+	proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size
+	resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size
 
 	kwargs = {
 		'n_text_tokens': 256,
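The added int() casts presumably guard against cfg.dataset.frames_per_second holding a non-integer value, since tensor slice bounds must be integers. A small sketch of that failure mode; the 86.1 figure and tensor shape are illustrative, not taken from the config:

import torch

audio = torch.zeros(1024, 8, dtype=torch.int16)  # placeholder (frame, codebook) tensor
frames_per_second = 86.1                         # illustrative fractional frame rate

# audio[:frames_per_second, :] would raise a TypeError (slice indices must be integers),
# so the example casts to int before slicing:
proms = audio[:int(frames_per_second), :]        # roughly the first second of frames
resps = audio[:int(frames_per_second * 4), :]    # roughly the first four seconds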