diff --git a/vall_e/data.py b/vall_e/data.py
index 85129a1..ca2a923 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -753,56 +753,98 @@ def create_train_val_dataloader():
     return train_dl, subtrain_dl, val_dl
 
 # parse dataset into better to sample metadata
-def create_dataset_metadata():
-    # need to fix
-    if True:
-        return
-
-    cfg.dataset.validate = False
-    cfg.dataset.use_hdf5 = False
-
-    paths_by_spkr_name = {}
-
-    paths_by_spkr_name |= _load_paths(cfg.dataset.training, "training")
-    paths_by_spkr_name |= _load_paths(cfg.dataset.validation, "validation")
-    paths_by_spkr_name |= _load_paths(cfg.dataset.noise, "noise")
-
-    paths = list(itertools.chain.from_iterable(paths_by_spkr_name.values()))
-
-    metadata = {}
-    for path in tqdm(paths, desc="Parsing paths"):
-        if isinstance(path, str):
-            print("str:", path)
-            path = Path(path)
-
-        speaker = cfg.get_spkr(path)
-        if speaker not in metadata:
-            metadata[speaker] = {}
-
-        if cfg.dataset.use_hdf5:
-            phones = cfg.hdf5[_get_hdf5_path(path)].attrs['phonemes']
-            duration = cfg.hdf5[_get_hdf5_path(path)].attrs['duration']
-        else:
-            phns_path = _get_phone_path(path)
-            qnts_path = _get_quant_path(path)
-
-            phones = len(_get_phones(phns_path)) if phns_path.exists() else 0
-            duration = _load_quants(qnts_path).shape[0] / 75 if qnts_path.exists() else 0
-
-
-        metadata[speaker][path.name.split(".")[0]] = {
-            "phones": phones,
-            "duration": duration
-        }
-
-    for speaker, paths in tqdm(paths_by_spkr_name.items(), desc="Writing metadata"):
-        if len(paths) == 0:
-            continue
-        with open(paths[0].parent / "metadata.json", "w", encoding="utf-8") as f:
-            f.write( json.dumps( metadata[speaker] ) )
+def create_dataset_metadata( skip_existing=False ):
+    symmap = get_phone_symmap()
 
-    with open(cfg.relpath / "metadata.json", "w", encoding="utf-8") as f:
-        f.write( json.dumps( metadata ) )
+    root = str(cfg.data_dir)
+    metadata_root = str(cfg.metadata_dir)
+
+    cfg.metadata_dir.mkdir(parents=True, exist_ok=True)
+
+    def add( dir, type="training", audios=True, texts=True ):
+        name = str(dir)
+        name = name.replace(root, "")
+
+        metadata_path = Path(f"{metadata_root}/{name}.json")
+
+        metadata = {} if not metadata_path.exists() else json.loads(open(str(metadata_path), "r", encoding="utf-8").read())
+
+        if not os.path.isdir(f'{root}/{name}/'):
+            return
+        # tqdm.write(f'{root}/{name}')
+        files = os.listdir(f'{root}/{name}/')
+
+        # grab IDs for every file
+        ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
+
+        for id in tqdm(ids, desc=f"Processing {name}"):
+            try:
+                audio_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
+                text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
+
+                if not audio_exists or not text_exists:
+                    continue
+
+                key = f'{type}/{name}/{id}'
+
+                if skip_existing and key in metadata:
+                    continue
+
+                if id not in metadata:
+                    metadata[id] = {}
+
+                # audio
+                if audios:
+                    if _get_quant_extension() == ".dac":
+                        dac = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+                        qnt = torch.from_numpy(dac["codes"].astype(int))[0].t().to(dtype=torch.int16)
+
+                        duration = dac["metadata"]["original_length"] / dac["metadata"]["sample_rate"]
+                        metadata[id]["metadata"] = {
+                            "original_length": dac["metadata"]["original_length"],
+                            "sample_rate": dac["metadata"]["sample_rate"],
+                        }
+                    else:
+                        qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t()
+                        duration = qnt.shape[0] / 75
+
+                    metadata[id]["duration"] = duration
+                else:
+                    metadata[id]["duration"] = 0
+
+                # text
+                if texts:
+                    if _get_phone_extension() == ".json":
+                        json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
+                        content = json_metadata["phonemes"]
+                        txt = json_metadata["text"]
+                    else:
+                        content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ")
+                        txt = ""
+
+                    phn = cfg.tokenizer.encode("".join(content))
+                    phn = np.array(phn).astype(np.uint8)
+
+                    metadata[id]["phones"] = len(phn)
+                    metadata[id]["transcription"] = txt
+            except Exception as e:
+                raise e
+                #pass
+
+        with open(str(metadata_path), "w", encoding="utf-8") as f:
+            f.write( json.dumps( metadata ) )
+
+    # training
+    for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
+        add( data_dir, type="training" )
+
+    # validation
+    for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
+        add( data_dir, type="validation" )
+
+    # noise
+    for data_dir in tqdm(sorted(cfg.dataset.noise), desc='Processing Noise'):
+        add( data_dir, type="noise", texts=False )
 
 # parse yaml to create an hdf5 file
 def create_dataset_hdf5( skip_existing=True ):
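
Context for reviewers, not part of the diff: a minimal sketch of how the reworked entry point might be driven and what each per-dataset JSON ends up holding, inferred from the code above. It assumes `vall_e` is importable and `cfg` has already been populated from a YAML in the usual way; the record shape simply mirrors the keys the new function writes.

    # Illustrative only: rebuild metadata for every configured dataset,
    # reusing any entries already present in each <metadata_dir>/<name>.json.
    from vall_e.data import create_dataset_metadata

    create_dataset_metadata( skip_existing=True )

    # Each JSON maps utterance IDs to records shaped roughly like:
    # {
    #     "duration": 4.2,          # seconds; qnt.shape[0] / 75 for non-DAC quants
    #     "phones": 37,             # token count from cfg.tokenizer.encode(...)
    #     "transcription": "...",   # non-empty only when phones come from .json files
    #     "metadata": {             # present only for .dac quantized audio
    #         "original_length": 185563,
    #         "sample_rate": 44100
    #     }
    # }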