converting over to a different intermediary dataset format
parent 4f5c9e518a · commit 8214aa23d7
vall_e/data.py:

@@ -63,11 +63,17 @@ def get_task_symmap():
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
+def _get_quant_extension():
+	return ".dac"
+
+def _get_phone_extension():
+	return ".json"
+
 def _get_quant_path(path):
-	return _replace_file_extension(path, ".qnt.pt")
+	return _replace_file_extension(path, _get_quant_extension())
 
 def _get_phone_path(path):
-	return _replace_file_extension(path, ".phn.txt")
+	return _replace_file_extension(path, _get_phone_extension())
 
 _total_durations = {}
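For reference, a standalone sketch of how the path helper above behaves; the helper is copied out of the diff so it runs on its own, and the file names are made up:

from pathlib import Path

def _replace_file_extension(path, suffix):
	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

# everything after the first "." is dropped, so multi-part extensions
# like ".qnt.pt" and ".phn.txt" collapse cleanly into the new suffix
print(_replace_file_extension(Path("data/1234_5678.qnt.pt"), ".dac"))   # data/1234_5678.dac
print(_replace_file_extension(Path("data/1234_5678.phn.txt"), ".json")) # data/1234_5678.json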
@@ -101,7 +107,7 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
 	metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
 
 	if len(metadata) == 0:
-		return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
 
 	def key( dir, id ):
 		if not cfg.dataset.use_hdf5:
@@ -134,7 +140,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
 	key = f"/{type}{_get_hdf5_path(data_dir)}"
 	return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
 
-def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
+def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
 	if isinstance(path, str):
 		path = Path(path)
 
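One subtlety with the new default argument: Python evaluates defaults once, at function definition time, so extensions=_get_quant_extension() freezes the helper's return value at import. That is harmless here because the helper returns a constant, as this sketch with made-up names demonstrates:

def current_ext():
	return ".dac"

def find(extensions=current_ext()):  # default evaluated once, at definition time
	return extensions

print(find())  # ".dac", even if current_ext is redefined afterwards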
@@ -154,6 +160,10 @@ def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
 	return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
 
 def _load_quants(path) -> Tensor:
+	if _get_quant_extension() == ".dac":
+		qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
+
 	return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
 
 # prune consecutive spaces
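The .dac files read here are what descript-audio-codec's DACFile.save() writes: a pickled numpy object holding the code tensor plus bookkeeping metadata. A minimal sketch of the round-trip, using a stand-in artifact with made-up shapes rather than a real compressed file:

import numpy as np
import torch

# stand-in for a DACFile on disk: a dict with codes + metadata,
# saved as a pickled 0-d numpy object (shapes/values invented)
artifact = {
	"codes": np.zeros((1, 9, 1500), dtype=np.uint16),  # [batch, levels, frames]
	"metadata": {"original_length": 441000, "sample_rate": 44100},
}
with open("sample.dac", "wb") as f:
	np.save(f, artifact)

qnt = np.load("sample.dac", allow_pickle=True)[()]  # [()] unwraps the 0-d object array
codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(torch.int16)
print(codes.shape)  # torch.Size([1500, 9]) -> [frames, levels]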
@@ -162,8 +172,12 @@ def _cleanup_phones( phones, targets=[" "]):
 
 @cache
 def _get_phones(path, language="en"):
-	content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
-	content = _cleanup_phones( content )
+	if _get_phone_extension() == ".json":
+		metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
+		content = metadata["phonemes"]
+	else:
+		content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
+		content = _cleanup_phones( content )
 	return ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
 
 def _interleaved_reorder(l, fn):
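_get_phones now branches on the metadata format, and the JSON side only touches a "phonemes" key. A hedged sketch of a minimal file it would accept (the real phonemizer output may carry more fields):

import json
from pathlib import Path

# minimal metadata that satisfies the JSON branch above;
# only "phonemes" is actually read by _get_phones
Path("sample.json").write_text(json.dumps({
	"phonemes": ["HH", "AH", "L", "OW", "", "W", "ER", "L", "D"],
}), encoding="utf-8")

metadata = json.loads(Path("sample.json").read_text(encoding="utf-8"))
content = metadata["phonemes"]
# empty entries become spaces, and sentence markers bracket the sequence
print(["<s>"] + [" " if not p else p for p in content] + ["</s>"])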
@@ -807,11 +821,12 @@ def create_dataset_hdf5( skip_existing=True ):
 			files = os.listdir(f'{root}/{name}/')
 
 			# grab IDs for every file
-			ids = { ".".join(file.split(".")[:-2]) for file in files }
+			ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
 
 			for id in tqdm(ids, desc=f"Processing {name}"):
 				try:
-					audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
-					text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+					audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+					text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
 
 					if not audio_exists or not text_exists:
 						continue
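The new ID derivation strips the two known extensions instead of splitting on dots. A small illustration, plus the one caveat: str.replace removes the substring anywhere in the name, not just at the end.

files = ["1234_5678.dac", "1234_5678.json", "8765_4321.dac"]
ids = {f.replace(".dac", "").replace(".json", "") for f in files}
print(sorted(ids))  # ['1234_5678', '8765_4321']

# caveat: an ID that happens to contain ".dac" mid-name would be mangled;
# endswith-based stripping would be the strict fix
print("weird.dac.take2.dac".replace(".dac", ""))  # weird.take2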
@@ -831,21 +846,34 @@ def create_dataset_hdf5( skip_existing=True ):
 
 					# audio
 					if audios:
-						qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
-
-						if "audio" in group:
-							del group["audio"]
+						if _get_quant_extension() == ".dac":
+							qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+							codes = torch.from_numpy(qnt["codes"].astype(int))[0].t()
+
+							if "audio" in group:
+								del group["audio"]
+
+							duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"]
+							metadata[id]["metadata"] = qnt["metadata"]
+						else:
+							codes = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
+							duration = codes.shape[0] / 75
 
-						group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
-						group.attrs['duration'] = qnt.shape[0] # / 75
-						metadata[id]["duration"] = qnt.shape[0] # / 75
+						group.create_dataset('audio', data=codes.numpy(), compression='lzf')
+
+						group.attrs['duration'] = duration
+						metadata[id]["duration"] = duration
 					else:
 						group.attrs['duration'] = 0
 						metadata[id]["duration"] = 0
 
 					# text
 					if texts:
+						"""
 						content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") .read().split(" ")
+						"""
+						if _get_phone_extension() == ".json":
+							j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+							content = j_son["phonemes"]
+						else:
+							content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
 
 						phones = [f"<s>"] + [ " " if not p else p for p in content ] + [f"</s>"]
 						for s in set(phones):
 							if s not in symmap:
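Duration bookkeeping differs between the two layouts: a DAC artifact carries the original sample count and rate in its metadata, while the old .qnt.pt path inferred seconds from EnCodec's fixed 75 frames per second. Worked numbers:

# DAC: duration from the artifact's own metadata
metadata = {"original_length": 441000, "sample_rate": 44100}
print(metadata["original_length"] / metadata["sample_rate"])  # 10.0 seconds

# EnCodec .qnt.pt: rows of the [frames, levels] tensor over 75 fps
frames = 750
print(frames / 75)  # 10.0 seconds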
@@ -858,7 +886,6 @@ def create_dataset_hdf5( skip_existing=True ):
 
+						"""
 						group.create_dataset('text', data=phn, compression='lzf', chunks=True)
 						group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
+						"""
 
 						group.attrs['phonemes'] = len(phn)
 						metadata[id]["phones"] = len(phn)
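Taken together, the HDF5 side now stores one group per utterance, with the [frames, levels] codes under an "audio" dataset and durations and phoneme counts in the group attrs. A minimal sketch of that layout, with the group path and values invented for illustration:

import h5py
import numpy as np

with h5py.File("sample.h5", "w") as hf:
	group = hf.create_group("training/dataset/speaker/1234_5678")
	group.create_dataset("audio", data=np.zeros((1500, 9), dtype=np.int16), compression="lzf")
	group.attrs["duration"] = 10.0
	group.attrs["phonemes"] = 11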
vall_e/emb/qnt.py:

@@ -66,22 +66,22 @@ try:
 
 			# to-do, original implementation
+			"""
 			resample_fn = recons.resample
 			loudness_fn = recons.loudness
 
 			# If audio is > 10 minutes long, use the ffmpeg versions
 			if recons.signal_duration >= 10 * 60 * 60:
 				resample_fn = recons.ffmpeg_resample
 				loudness_fn = recons.ffmpeg_loudness
 
 			recons.normalize(obj.input_db)
 			resample_fn(obj.sample_rate)
 			recons = recons[..., : obj.original_length]
 			loudness_fn()
 			recons.audio_data = recons.audio_data.reshape(
 				-1, obj.channels, obj.original_length
 			)
+			"""
+			resample_fn = recons.resample
+			loudness_fn = recons.loudness
+
+			# If audio is > 10 minutes long, use the ffmpeg versions
+			if recons.signal_duration >= 10 * 60 * 60:
+				resample_fn = recons.ffmpeg_resample
+				loudness_fn = recons.ffmpeg_loudness
+
+			recons.normalize(obj.input_db)
+			resample_fn(obj.sample_rate)
+			recons = recons[..., : obj.original_length]
+			loudness_fn()
+			recons.audio_data = recons.audio_data.reshape(
+				-1, obj.channels, obj.original_length
+			)
 			self.padding = original_padding
 			return recons
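The recons[..., : obj.original_length] trim in the block above exists because codec output covers a whole number of frames, so decoded audio runs slightly long; slicing back to the recorded original length drops the padding tail. In isolation, with invented lengths:

import torch

original_length = 441007             # not frame-aligned
decoded = torch.zeros(1, 1, 441504)  # decoder output, rounded up to full frames
recons = decoded[..., :original_length]
print(recons.shape)  # torch.Size([1, 1, 441007])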
@@ -228,7 +228,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 		dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
 	)
 
-	return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
+	return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate
 
 
 kwargs = {}
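The return-value change matters because decompress() resamples back to the source audio's rate, so the artifact's recorded sample_rate, not the model's native one, describes the waveform handed back. A hedged end-to-end sketch against the published descript-audio-codec API, with the model variant and file names assumed for illustration:

import dac

# load the published 44kHz weights (variant assumed for illustration)
model = dac.DAC.load(dac.utils.download(model_type="44khz"))

artifact = model.compress("sample.wav")  # DACFile: codes + metadata
artifact.save("sample.dac")

recons = model.decompress("sample.dac", verbose=False)  # AudioSignal at the source rate
print(artifact.sample_rate, model.sample_rate)          # source rate vs. codec rate
recons.write("sample.recons.wav")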