diff --git a/vall_e/data.py b/vall_e/data.py index 79c2355..0863b76 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -127,7 +127,7 @@ def get_random_prompts( validation=False, min_length=0, tokenized=False ): text_string = metadata["text"] if "text" in metadata else "" duration = metadata['duration'] if "duration" in metadata else 0 else: - _, metadata = _load_quants(path, return_metadata=True) + _, metadata = _load_artifact(path, return_metadata=True) metadata = process_artifact_metadata( { "metadata": metadata } ) text_string = metadata["text"] if "text" in metadata else "" duration = metadata['duration'] if "duration" in metadata else 0 @@ -564,17 +564,14 @@ def _replace_file_extension(path, suffix): path = Path(path) return (path.parent / path.name.split(".")[0]).with_suffix(suffix) -def _get_quant_extension(): +def _get_artifact_extension(): return ".dac" if cfg.audio_backend == "dac" else ".enc" -def _get_phone_extension(): - return ".json" # if cfg.audio_backend == "dac" else ".phn.txt" +def _get_metadata_extension(): + return ".json" -def _get_quant_path(path): - return _replace_file_extension(path, _get_quant_extension()) - -def _get_phone_path(path): - return _replace_file_extension(path, _get_phone_extension()) +def _get_artifact_path(path): + return _replace_file_extension(path, _get_artifact_extension()) _durations_map = {} def _get_duration_map( type="training" ): @@ -627,7 +624,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): metadata = json_read( metadata_path ) if len(metadata) == 0: - return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate ) + return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate ) def _validate( id, entry ): phones = entry['phones'] if "phones" in entry else 0 @@ -671,37 +668,18 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ): return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else [] -def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ): +def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ): if isinstance(path, str): path = Path(path) return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else [] -def _load_quants(path, return_metadata=False) -> Tensor: - qnt = np.load(_get_quant_path(path), allow_pickle=True)[()] +def _load_artifact(path, return_metadata=False) -> Tensor: + qnt = np.load(_get_artifact_path(path), allow_pickle=True)[()] if return_metadata: return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"] return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16) -# prune consecutive spaces -def _cleanup_phones( phones, targets=[" "]): - return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ] - -@cache -def _get_phones(path): - phone_path = _get_phone_path(path) - quant_path = _get_quant_path(path) - if phone_path.exists(): - #metadata = json.loads(open(phone_path, "r", encoding="utf-8").read()) - metadata = json_read(phone_path) - elif quant_path.exists(): - _, metadata = _load_quants( path, return_metadata=True ) - else: - raise Exception(f"Could not load phonemes: {path}") - - content = metadata["phonemes"] - return "".join(content) - def _interleaved_reorder(l, fn): groups = defaultdict(list) for e in l: @@ -991,7 +969,7 @@ class Dataset(_Dataset): key = _get_hdf5_path(path) qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) else: - qnt = _load_quants(path, return_metadata=False) + qnt = _load_artifact(path, return_metadata=False) return qnt def sample_speakers(self, ignore=[]): @@ -1026,7 +1004,7 @@ class Dataset(_Dataset): tone = metadata["tone"] if "tone" in metadata else None """ else: - resps, metadata = _load_quants(path, return_metadata=True) + resps, metadata = _load_artifact(path, return_metadata=True) text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype) """ @@ -1112,7 +1090,7 @@ class Dataset(_Dataset): key = _get_hdf5_path(path) qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16) else: - qnt = _load_quants(path, return_metadata=False) + qnt = _load_artifact(path, return_metadata=False) if 0 < trim_length and trim_length < qnt.shape[0]: qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device ) @@ -1184,7 +1162,7 @@ class Dataset(_Dataset): if cfg.dataset.retokenize_text and "phonemes" in metadata: text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype) else: - resps, metadata = _load_quants(path, return_metadata=True) + resps, metadata = _load_artifact(path, return_metadata=True) text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype) lang = metadata["language"] if "language" in metadata else None @@ -1613,13 +1591,13 @@ def create_dataset_metadata( skip_existing=False ): files = os.listdir(f'{root}/{name}/') # grab IDs for every file - ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files } + ids = { file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files } wrote = False for id in tqdm(ids, desc=f"Processing {name}", disable=True): try: - quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}') + quant_path = Path(f'{root}/{name}/{id}{_get_artifact_extension()}') if audios and not quant_path.exists(): continue @@ -1696,7 +1674,7 @@ def create_dataset_hdf5( skip_existing=True ): files = os.listdir(f'{root}/{name}/') # grab IDs for every file - ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files } + ids = { file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files } """ # rephonemizes if you fuck up and use and old tokenizer... @@ -1724,8 +1702,8 @@ def create_dataset_hdf5( skip_existing=True ): for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose): try: - quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True - text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True + quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_artifact_extension()}') if audios else True + text_exists = os.path.exists(f'{root}/{name}/{id}{_get_metadata_extension()}') if texts else True if not quant_exists: continue @@ -1744,7 +1722,7 @@ def create_dataset_hdf5( skip_existing=True ): # audio if audios: - artifact = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] + artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()] qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16) utterance_metadata = process_artifact_metadata( artifact ) @@ -1757,7 +1735,7 @@ def create_dataset_hdf5( skip_existing=True ): # to-do: ensure I can remove this block if texts: if not utterance_metadata and text_exists: - utterance_metadata = json_read(f'{root}/{name}/{id}{_get_phone_extension()}') + utterance_metadata = json_read(f'{root}/{name}/{id}{_get_metadata_extension()}') phn = "".join(utterance_metadata["phonemes"]) phn = cfg.tokenizer.encode(phn) @@ -1883,7 +1861,7 @@ if __name__ == "__main__": continue metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() } else: - _, metadata = _load_quants(path, return_metadata=True) + _, metadata = _load_artifact(path, return_metadata=True) phonemes = metadata["phonemes"] diff --git a/vall_e/demo.py b/vall_e/demo.py index e676cf0..8e626e6 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -32,6 +32,7 @@ from .data import create_train_dataloader, create_val_dataloader, get_random_pro from .emb.qnt import decode_to_file from .metrics import wer, sim_o from .utils import setup_logging +from .utils.io import json_read, json_write from tqdm import tqdm, trange @@ -348,6 +349,7 @@ def main(): language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en" prompt = dir / "prompt.wav" reference = dir / "reference.wav" + metrics_path = dir / "metrics.json" out_path = dir / "out" / "ours.wav" out_path_comparison = dir / "out" / f"ours_{comparison_kwargs['suffix']}.wav" external_sources = [ dir / "out" / f"{source}.wav" for source in sources ] @@ -374,15 +376,19 @@ def main(): # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs) if args.comparison: - if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing): + should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing) + + if should_generate: comparison_inputs.append((text, prompt, language, out_path_comparison)) - metrics_inputs.append((text, language, out_path_comparison, reference)) + metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path)) - if (args.skip_existing and not out_path.exists()) or not (args.skip_existing): + should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing) + + if should_generate: inputs.append((text, prompt, language, out_path)) - metrics_inputs.append((text, language, out_path, reference)) + metrics_inputs.append((text, language, out_path, reference, metrics_path)) outputs.append((k, samples)) @@ -393,10 +399,19 @@ def main(): process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) ) metrics_map = {} - total_metrics = (0, 0) - for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"): - wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) - sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) + for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"): + calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime) + + if calculate: + wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model ) + sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model ) + + metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score} + json_write( metrics, metrics_path ) + else: + metrics = json_read( metrics_path ) + wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"] + metrics_map[out_path] = (wer_score, cer_score, sim_o_score) # collate entries into HTML diff --git a/vall_e/emb/g2p.py b/vall_e/emb/g2p.py index 7e1f5cb..61acae8 100755 --- a/vall_e/emb/g2p.py +++ b/vall_e/emb/g2p.py @@ -23,7 +23,6 @@ except Exception as e: langdetect = None print(f'Error while importing langdetect: {str(e)}') -@cache def detect_language( text ): if langdetect is None: raise Exception('langdetect is not installed.') @@ -34,7 +33,6 @@ def _get_graphs(path): graphs = f.read() return graphs -@cache def coerce_to_hiragana( runes, sep="" ): if pykakasi is None: raise Exception('pykakasi is not installed.') diff --git a/vall_e/inference.py b/vall_e/inference.py index d8d9141..5e50e07 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -19,7 +19,7 @@ from .config import cfg, Config from .models import get_models from .models.lora import enable_lora from .engines import load_engines, deepspeed_available -from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize, sentence_split +from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split from .models import download_model, DEFAULT_MODEL_PATH if deepspeed_available: