converting over to a different intermediary dataset format

mrq 2024-04-18 21:24:06 -05:00
parent 4f5c9e518a
commit 8214aa23d7
2 changed files with 60 additions and 33 deletions
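In short: quantized audio moves from torch-saved .qnt.pt tensors to numpy-pickled .dac artifacts, and phoneme transcriptions move from space-separated .phn.txt files to .json metadata. Below is a minimal sketch of the loaders implied by the diff; the field names come from the diff itself, but the exact on-disk layout of the .dac artifact and the helper names here are assumptions:

import json
import numpy as np
import torch

def load_codes_dac(path):
    # new intermediary format: a numpy-pickled dict holding a DAC-style artifact
    qnt = np.load(path, allow_pickle=True)[()]
    # qnt["codes"] holds the quantizer codes; qnt["metadata"] carries
    # original_length, sample_rate, dac_version, ...
    return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)

def load_codes_qnt_pt(path):
    # previous intermediary format: a torch-saved tensor
    return torch.load(path)[0][:, :].t().to(torch.int16)

def load_phonemes_json(path):
    # new phoneme sidecar: a JSON file with a "phonemes" list
    return json.loads(open(path, "r", encoding="utf-8").read())["phonemes"]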

View File

@@ -63,11 +63,17 @@ def get_task_symmap():
 def _replace_file_extension(path, suffix):
     return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

+def _get_quant_extension():
+    return ".dac"
+
+def _get_phone_extension():
+    return ".json"
+
 def _get_quant_path(path):
-    return _replace_file_extension(path, ".qnt.pt")
+    return _replace_file_extension(path, _get_quant_extension())

 def _get_phone_path(path):
-    return _replace_file_extension(path, ".phn.txt")
+    return _replace_file_extension(path, _get_phone_extension())

 _total_durations = {}
@@ -101,7 +107,7 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
     metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())

     if len(metadata) == 0:
-        return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+        return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )

     def key( dir, id ):
         if not cfg.dataset.use_hdf5:
@@ -134,7 +140,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
     key = f"/{type}{_get_hdf5_path(data_dir)}"
     return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []

-def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
+def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
     if isinstance(path, str):
         path = Path(path)
@@ -154,6 +160,10 @@ def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
     return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []

 def _load_quants(path) -> Tensor:
+    if _get_quant_extension() == ".dac":
+        qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+        return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
     return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)

 # prune consecutive spaces
@@ -162,8 +172,12 @@ def _cleanup_phones( phones, targets=[" "]):
 @cache
 def _get_phones(path, language="en"):
-    content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
-    content = _cleanup_phones( content )
+    if _get_quant_extension() == ".json":
+        metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
+        content = metadata["phonemes"]
+    else:
+        content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
+        content = _cleanup_phones( content )

     return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]

 def _interleaved_reorder(l, fn):
@@ -807,11 +821,12 @@ def create_dataset_hdf5( skip_existing=True ):
         files = os.listdir(f'{root}/{name}/')

         # grab IDs for every file
-        ids = { ".".join(file.split(".")[:-2]) for file in files }
+        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }

         for id in tqdm(ids, desc=f"Processing {name}"):
             try:
-                audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
-                text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+                audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+                text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True

                 if not audio_exists or not text_exists:
                     continue
@@ -831,21 +846,34 @@ def create_dataset_hdf5( skip_existing=True ):
                 # audio
                 if audios:
-                    qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+                    qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+                    codes = torch.from_numpy(qnt["codes"].astype(int))[0].t()
+
+                    if _get_quant_extension() == ".dac":
+                        if "audio" in group:
+                            del group["audio"]
+                        duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
+                        metadata[id]["metadata"] = qnt["metadata"]
+                    else:
+                        qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
+                        duration = qnt.shape[0] / 75

                     if "audio" in group:
                         del group["audio"]
                     group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
-                    group.attrs['duration'] = qnt.shape[0] # / 75
-                    metadata[id]["duration"] = qnt.shape[0] # / 75
+                    group.attrs['duration'] = duration
+                    metadata[id]["duration"] = duration
                 else:
                     group.attrs['duration'] = 0
                     metadata[id]["duration"] = 0

                 # text
                 if texts:
                     """
-                    content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") .read().split(" ")
+                    if _get_quant_extension() == ".json":
+                        j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+                        content = j_son["phonemes"]
+                    else:
+                        content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")

                     phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
                     for s in set(phones):
                         if s not in symmap:
@@ -858,7 +886,6 @@ def create_dataset_hdf5( skip_existing=True ):
                     group.create_dataset('text', data=phn, compression='lzf', chunks=True)
                     group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
                     """
-                    group.attrs['phonemes'] = len(phn)
                     metadata[id]["phones"] = len(phn)
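A side effect visible in the hunk above: duration bookkeeping now differs per format. A .dac artifact's duration comes from its own metadata (original length divided by sample rate, in seconds), while the legacy torch.load path still derives it from the frame count at 75 frames per second. A small sketch of that logic; the helper name is illustrative and not from the repo:

def duration_seconds(qnt, is_dac: bool) -> float:
    # .dac artifacts carry their own metadata; legacy codes run at 75 frames per second
    if is_dac:
        meta = qnt["metadata"]
        return meta["original_length"] / meta["sample_rate"]
    return qnt.shape[0] / 75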

View File

@@ -66,22 +66,22 @@ try:
         # to-do, original implementation
-        """
-        resample_fn = recons.resample
-        loudness_fn = recons.loudness
-
-        # If audio is > 10 minutes long, use the ffmpeg versions
-        if recons.signal_duration >= 10 * 60 * 60:
-            resample_fn = recons.ffmpeg_resample
-            loudness_fn = recons.ffmpeg_loudness
-
-        recons.normalize(obj.input_db)
-        resample_fn(obj.sample_rate)
-        recons = recons[..., : obj.original_length]
-        loudness_fn()
-        recons.audio_data = recons.audio_data.reshape(
-            -1, obj.channels, obj.original_length
-        )
-        """
+        resample_fn = recons.resample
+        loudness_fn = recons.loudness
+
+        # If audio is > 10 minutes long, use the ffmpeg versions
+        if recons.signal_duration >= 10 * 60 * 60:
+            resample_fn = recons.ffmpeg_resample
+            loudness_fn = recons.ffmpeg_loudness
+
+        recons.normalize(obj.input_db)
+        resample_fn(obj.sample_rate)
+        recons = recons[..., : obj.original_length]
+        loudness_fn()
+        recons.audio_data = recons.audio_data.reshape(
+            -1, obj.channels, obj.original_length
+        )

         self.padding = original_padding
         return recons
@@ -228,7 +228,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
             dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
         )

-        return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
+        return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate

     kwargs = {}
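The final hunk makes decode() report the sample rate stored in the decoded artifact rather than the loaded model's rate. A hypothetical caller, assuming torchaudio for writing the output and a codes tensor plus metadata dict taken from one of the .dac artifacts (neither is shown in the diff):

import torchaudio

# codes / qnt["metadata"] are assumed to come from a .dac artifact prepared above
wav, sr = decode(codes, device="cuda", metadata=qnt["metadata"])  # sr now follows the artifact
torchaudio.save("out.wav", wav.cpu(), sr)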