diff --git a/vall_e/data.py b/vall_e/data.py
index e922461..fc49975 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -63,11 +63,17 @@ def get_task_symmap():
 
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
+def _get_quant_extension():
+	return ".dac"
+
+def _get_phone_extension():
+	return ".json"
+
 def _get_quant_path(path):
-	return _replace_file_extension(path, ".qnt.pt")
+	return _replace_file_extension(path, _get_quant_extension())
 
 def _get_phone_path(path):
-	return _replace_file_extension(path, ".phn.txt")
+	return _replace_file_extension(path, _get_phone_extension())
 
 _total_durations = {}
@@ -101,7 +107,7 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
 	metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
 
 	if len(metadata) == 0:
-		return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
+		return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
 
 	def key( dir, id ):
 		if not cfg.dataset.use_hdf5:
@@ -134,7 +140,7 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
 	key = f"/{type}{_get_hdf5_path(data_dir)}"
 	return [ Path(f"{key}/{child.attrs['id']}") for child in cfg.hdf5[key].values() if not validate or _validate(child) ] if key in cfg.hdf5 else []
 
-def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
+def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
 	if isinstance(path, str):
 		path = Path(path)
 
@@ -154,6 +160,10 @@ def _get_paths_of_extensions( path, extensions=".qnt.pt", validate=False ):
 	return [ p for p in list(path.iterdir()) if _validate(p) ] if path.exists() and path.is_dir() else []
 
 def _load_quants(path) -> Tensor:
+	if _get_quant_extension() == ".dac":
+		qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+		return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
+
 	return torch.load(_get_quant_path(path))[0][:, :].t().to(torch.int16)
 
 # prune consecutive spaces
@@ -162,8 +172,12 @@ def _cleanup_phones( phones, targets=[" "]):
 
 @cache
 def _get_phones(path, language="en"):
-	content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
-	content = _cleanup_phones( content )
+	if _get_phone_extension() == ".json":
+		metadata = json.loads(open(_get_phone_path(path), "r", encoding="utf-8").read())
+		content = metadata["phonemes"]
+	else:
+		content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ")
+		content = _cleanup_phones( content )
 	return [""] + [ " " if not p else p for p in content ] + [""]
 
 def _interleaved_reorder(l, fn):
@@ -807,11 +821,12 @@ def create_dataset_hdf5( skip_existing=True ):
 			files = os.listdir(f'{root}/{name}/')
 
 			# grab IDs for every file
-			ids = { ".".join(file.split(".")[:-2]) for file in files }
+			ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
+
 			for id in tqdm(ids, desc=f"Processing {name}"):
 				try:
-					audio_exists = os.path.exists(f'{root}/{name}/{id}.qnt.pt') if audios else True
-					text_exists = os.path.exists(f'{root}/{name}/{id}.phn.txt') if texts else True
+					audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+					text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
 
 					if not audio_exists or not text_exists:
 						continue
@@ -831,21 +846,34 @@ def create_dataset_hdf5( skip_existing=True ):
 
 					# audio
 					if audios:
-						qnt = torch.load(f'{root}/{name}/{id}.qnt.pt')[0].t()
+						if _get_quant_extension() == ".dac":
+							artifact = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+							qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t()
 
-						if "audio" in group:
-							del group["audio"]
+							if "audio" in group:
+								del group["audio"]
+							duration = artifact["metadata"]["original_length"] / artifact["metadata"]["sample_rate"]
+							metadata[id]["metadata"] = artifact["metadata"]
+						else:
+							qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
+							duration = qnt.shape[0] / 75
+
 						group.create_dataset('audio', data=qnt.numpy(), compression='lzf')
-						group.attrs['duration'] = qnt.shape[0] # / 75
-						metadata[id]["duration"] = qnt.shape[0] # / 75
+
+						group.attrs['duration'] = duration
+						metadata[id]["duration"] = duration
 					else:
 						group.attrs['duration'] = 0
 						metadata[id]["duration"] = 0
 
 					# text
 					if texts:
-						"""
-						content = open(f'{root}/{name}/{id}.phn.txt', "r", encoding="utf-8") .read().split(" ")
+						if _get_phone_extension() == ".json":
+							j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+							content = j_son["phonemes"]
+						else:
+							content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
+
 						phones = [f""] + [ " " if not p else p for p in content ] + [f""]
 						for s in set(phones):
 							if s not in symmap:
@@ -858,7 +886,6 @@ def create_dataset_hdf5( skip_existing=True ):
 
 						group.create_dataset('text', data=phn, compression='lzf', chunks=True)
 						group.create_dataset('transcription', data=txt, compression='lzf', chunks=True)
-						"""
 
 						group.attrs['phonemes'] = len(phn)
 						metadata[id]["phones"] = len(phn)
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 40827ad..c5af8ac 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -66,22 +66,22 @@ try:
 
 		# to-do, original implementation
 		"""
-		resample_fn = recons.resample
-		loudness_fn = recons.loudness
-
-		# If audio is > 10 minutes long, use the ffmpeg versions
-		if recons.signal_duration >= 10 * 60 * 60:
-			resample_fn = recons.ffmpeg_resample
-			loudness_fn = recons.ffmpeg_loudness
-
-		recons.normalize(obj.input_db)
-		resample_fn(obj.sample_rate)
-		recons = recons[..., : obj.original_length]
-		loudness_fn()
-		recons.audio_data = recons.audio_data.reshape(
-			-1, obj.channels, obj.original_length
-		)
 		"""
+		resample_fn = recons.resample
+		loudness_fn = recons.loudness
+
+		# If audio is > 10 minutes long, use the ffmpeg versions
+		if recons.signal_duration >= 10 * 60 * 60:
+			resample_fn = recons.ffmpeg_resample
+			loudness_fn = recons.ffmpeg_loudness
+
+		recons.normalize(obj.input_db)
+		resample_fn(obj.sample_rate)
+		recons = recons[..., : obj.original_length]
+		loudness_fn()
+		recons.audio_data = recons.audio_data.reshape(
+			-1, obj.channels, obj.original_length
+		)
 
 		self.padding = original_padding
 		return recons
@@ -228,7 +228,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N
 
 			dac_version = metadata["dac_version"] if isinstance(metadata, dict) else metadata.dac_version,
 		)
 
-		return model.decompress(artifact, verbose=False).audio_data[0], model.sample_rate
+		return model.decompress(artifact, verbose=False).audio_data[0], artifact.sample_rate
 
 	kwargs = {}