converting over to a different intermediary dataset format

This commit is contained in:
mrq 2024-04-18 21:24:06 -05:00
parent 4f5c9e518a
commit 8214aa23d7
2 changed files with 60 additions and 33 deletions

View File

@ -63,11 +63,17 @@ def get_task_symmap():
def _replace_file_extension(path, suffix): def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix) return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def _get_quant_extension():
return ".dac"
def _get_phone_extension():
return ".json"
def _get_quant_path(path): def _get_quant_path(path):
return _replace_file_extension(path, ".qnt.pt") return _replace_file_extension(path, _get_quant_extension())
def _get_phone_path(path): def _get_phone_path(path):
return _replace_file_extension(path, ".phn.txt") return _replace_file_extension(path, _get_phone_extension())
_total_durations = {} _total_durations = {}
@ -101,7 +107,7 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read()) metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
if len(metadata) == 0: if len(metadata) == 0:
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate ) return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
def key( dir, id ): def key( dir, id ):
if not cfg.dataset.use_hdf5: if not cfg.dataset.use_hdf5:
@ -134,7 +140,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
key = f"/{type}{_get_hdf5_path(data_dir)}" key = f"/{type}{_get_hdf5_path(data_dir)}"
return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else [] return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ): def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
if isinstance(path, str): if isinstance(path, str):
path = Path(path) path = Path(path)
@ -154,6 +160,10 @@ def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else [] return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
def _load_quants(path) -> Tensor: def _load_quants(path) -> Tensor:
if _get_quant_extension() == ".dac":
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16) return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
# prune consecutive spaces # prune consecutive spaces
@ -162,8 +172,12 @@ def _cleanup_phones( phones, targets=[" "]):
@cache @cache
def _get_phones(path, language="en"): def _get_phones(path, language="en"):
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ") if _get_quant_extension() == ".json":
content = _cleanup_phones( content ) metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
content = metadata["phonemes"]
else:
content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
content = _cleanup_phones( content )
return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"] return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
def _interleaved_reorder(l, fn): def _interleaved_reorder(l, fn):
@ -807,11 +821,12 @@ def create_dataset_hdf5( skip_existing=True ):
files = os.listdir(f'{root}/{name}/') files = os.listdir(f'{root}/{name}/')
# grab IDs for every file # grab IDs for every file
ids = { ".".join(file.split(".")[:-2]) for file in files } ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
for id in tqdm(ids, desc=f"Processing {name}"): for id in tqdm(ids, desc=f"Processing {name}"):
try: try:
audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
if not audio_exists or not text_exists: if not audio_exists or not text_exists:
continue continue
@ -831,21 +846,34 @@ def create_dataset_hdf5( skip_existing=True ):
# audio # audio
if audios: if audios:
qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t() qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
codes = torch.from_numpy(qnt["codes"].astype(int))[0].t()
if _get_quant_extension() == ".dac":
if "audio" in group:
del group["audio"]
duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
metadata[id]["metadata"] = qnt["metadata"]
else:
qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
duration = qnt.shape[0] / 75
if "audio" in group:
del group["audio"]
group.create_dataset('audio', data=qnt.numpy(), compression='lzf') group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
group.attrs['duration'] = qnt.shape[0] # / 75
metadata[id]["duration"] = qnt.shape[0] # / 75 group.attrs['duration'] = duration
metadata[id]["duration"] = duration
else: else:
group.attrs['duration'] = 0 group.attrs['duration'] = 0
metadata[id]["duration"] = 0 metadata[id]["duration"] = 0
# text # text
if texts: if texts:
""" if _get_quant_extension() == ".json":
content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") .read().split(" ") j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
content = j_son["phonemes"]
else:
content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"] phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
for s in set(phones): for s in set(phones):
if s not in symmap: if s not in symmap:
@ -858,7 +886,6 @@ def create_dataset_hdf5( skip_existing=True ):
group.create_dataset('text', data=phn, compression='lzf', chunks=True) group.create_dataset('text', data=phn, compression='lzf', chunks=True)
group.create_dataset('transcription', data=txt, compression='lzf', chunks=True) group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
"""
group.attrs['phonemes'] = len(phn) group.attrs['phonemes'] = len(phn)
metadata[id]["phones"] = len(phn) metadata[id]["phones"] = len(phn)

View File

@ -66,22 +66,22 @@ try:
# to-do, original implementation # to-do, original implementation
""" """
resample_fn = recons.resample
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
""" """
resample_fn = recons.resample
loudness_fn = recons.loudness
# If audio is > 10 minutes long, use the ffmpeg versions
if recons.signal_duration >= 10 * 60 * 60:
resample_fn = recons.ffmpeg_resample
loudness_fn = recons.ffmpeg_loudness
recons.normalize(obj.input_db)
resample_fn(obj.sample_rate)
recons = recons[..., : obj.original_length]
loudness_fn()
recons.audio_data = recons.audio_data.reshape(
-1, obj.channels, obj.original_length
)
self.padding = original_padding self.padding = original_padding
return recons return recons
@ -228,7 +228,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version, dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
) )
return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate
kwargs = {} kwargs = {}