store metrics and only recalculate them if the output file is newer than the metrics file
This commit is contained in:
parent
0c69e798f7
commit
20b87bfbd0
|
@ -127,7 +127,7 @@ def get_random_prompts( validation=False, min_length=0, tokenized=False ):
|
|||
text_string = metadata["text"] if "text" in metadata else ""
|
||||
duration = metadata['duration'] if "duration" in metadata else 0
|
||||
else:
|
||||
_, metadata = _load_quants(path, return_metadata=True)
|
||||
_, metadata = _load_artifact(path, return_metadata=True)
|
||||
metadata = process_artifact_metadata( { "metadata": metadata } )
|
||||
text_string = metadata["text"] if "text" in metadata else ""
|
||||
duration = metadata['duration'] if "duration" in metadata else 0
|
||||
|
@ -564,17 +564,14 @@ def _replace_file_extension(path, suffix):
|
|||
path = Path(path)
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def _get_quant_extension():
|
||||
def _get_artifact_extension():
|
||||
return ".dac" if cfg.audio_backend == "dac" else ".enc"
|
||||
|
||||
def _get_phone_extension():
|
||||
return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"
|
||||
def _get_metadata_extension():
|
||||
return ".json"
|
||||
|
||||
def _get_quant_path(path):
|
||||
return _replace_file_extension(path, _get_quant_extension())
|
||||
|
||||
def _get_phone_path(path):
|
||||
return _replace_file_extension(path, _get_phone_extension())
|
||||
def _get_artifact_path(path):
|
||||
return _replace_file_extension(path, _get_artifact_extension())
|
||||
|
||||
_durations_map = {}
|
||||
def _get_duration_map( type="training" ):
|
||||
|
@ -627,7 +624,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
|
|||
metadata = json_read( metadata_path )
|
||||
|
||||
if len(metadata) == 0:
|
||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
|
||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
|
||||
|
||||
def _validate( id, entry ):
|
||||
phones = entry['phones'] if "phones" in entry else 0
|
||||
|
@ -671,37 +668,18 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
|
|||
|
||||
return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
|
||||
|
||||
def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
|
||||
def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
|
||||
|
||||
def _load_quants(path, return_metadata=False) -> Tensor:
|
||||
qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
|
||||
def _load_artifact(path, return_metadata=False) -> Tensor:
|
||||
qnt = np.load(_get_artifact_path(path), allow_pickle=True)[()]
|
||||
if return_metadata:
|
||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
|
||||
return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
|
||||
|
||||
# prune consecutive spaces
|
||||
def _cleanup_phones( phones, targets=[" "]):
|
||||
return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
|
||||
|
||||
@cache
|
||||
def _get_phones(path):
|
||||
phone_path = _get_phone_path(path)
|
||||
quant_path = _get_quant_path(path)
|
||||
if phone_path.exists():
|
||||
#metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
|
||||
metadata = json_read(phone_path)
|
||||
elif quant_path.exists():
|
||||
_, metadata = _load_quants( path, return_metadata=True )
|
||||
else:
|
||||
raise Exception(f"Could not load phonemes: {path}")
|
||||
|
||||
content = metadata["phonemes"]
|
||||
return "".join(content)
|
||||
|
||||
def _interleaved_reorder(l, fn):
|
||||
groups = defaultdict(list)
|
||||
for e in l:
|
||||
|
@ -991,7 +969,7 @@ class Dataset(_Dataset):
|
|||
key = _get_hdf5_path(path)
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
else:
|
||||
qnt = _load_quants(path, return_metadata=False)
|
||||
qnt = _load_artifact(path, return_metadata=False)
|
||||
return qnt
|
||||
|
||||
def sample_speakers(self, ignore=[]):
|
||||
|
@ -1026,7 +1004,7 @@ class Dataset(_Dataset):
|
|||
tone = metadata["tone"] if "tone" in metadata else None
|
||||
"""
|
||||
else:
|
||||
resps, metadata = _load_quants(path, return_metadata=True)
|
||||
resps, metadata = _load_artifact(path, return_metadata=True)
|
||||
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||
|
||||
"""
|
||||
|
@ -1112,7 +1090,7 @@ class Dataset(_Dataset):
|
|||
key = _get_hdf5_path(path)
|
||||
qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
|
||||
else:
|
||||
qnt = _load_quants(path, return_metadata=False)
|
||||
qnt = _load_artifact(path, return_metadata=False)
|
||||
|
||||
if 0 < trim_length and trim_length < qnt.shape[0]:
|
||||
qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
|
||||
|
@ -1184,7 +1162,7 @@ class Dataset(_Dataset):
|
|||
if cfg.dataset.retokenize_text and "phonemes" in metadata:
|
||||
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||
else:
|
||||
resps, metadata = _load_quants(path, return_metadata=True)
|
||||
resps, metadata = _load_artifact(path, return_metadata=True)
|
||||
text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
|
||||
|
||||
lang = metadata["language"] if "language" in metadata else None
|
||||
|
@ -1613,13 +1591,13 @@ def create_dataset_metadata( skip_existing=False ):
|
|||
files = os.listdir(f'{root}/{name}/')
|
||||
|
||||
# grab IDs for every file
|
||||
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
|
||||
ids = { file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files }
|
||||
|
||||
wrote = False
|
||||
|
||||
for id in tqdm(ids, desc=f"Processing {name}", disable=True):
|
||||
try:
|
||||
quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
|
||||
quant_path = Path(f'{root}/{name}/{id}{_get_artifact_extension()}')
|
||||
|
||||
if audios and not quant_path.exists():
|
||||
continue
|
||||
|
@ -1696,7 +1674,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
files = os.listdir(f'{root}/{name}/')
|
||||
|
||||
# grab IDs for every file
|
||||
ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
|
||||
ids = { file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files }
|
||||
|
||||
"""
|
||||
# rephonemizes if you mess up and use an old tokenizer...
|
||||
|
@ -1724,8 +1702,8 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
|
||||
try:
|
||||
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
|
||||
quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_artifact_extension()}') if audios else True
|
||||
text_exists = os.path.exists(f'{root}/{name}/{id}{_get_metadata_extension()}') if texts else True
|
||||
|
||||
if not quant_exists:
|
||||
continue
|
||||
|
@ -1744,7 +1722,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
# audio
|
||||
if audios:
|
||||
artifact = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
|
||||
artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()]
|
||||
qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
|
||||
utterance_metadata = process_artifact_metadata( artifact )
|
||||
|
@ -1757,7 +1735,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
# to-do: ensure I can remove this block
|
||||
if texts:
|
||||
if not utterance_metadata and text_exists:
|
||||
utterance_metadata = json_read(f'{root}/{name}/{id}{_get_phone_extension()}')
|
||||
utterance_metadata = json_read(f'{root}/{name}/{id}{_get_metadata_extension()}')
|
||||
|
||||
phn = "".join(utterance_metadata["phonemes"])
|
||||
phn = cfg.tokenizer.encode(phn)
|
||||
|
@ -1883,7 +1861,7 @@ if __name__ == "__main__":
|
|||
continue
|
||||
metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
|
||||
else:
|
||||
_, metadata = _load_quants(path, return_metadata=True)
|
||||
_, metadata = _load_artifact(path, return_metadata=True)
|
||||
|
||||
phonemes = metadata["phonemes"]
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from .data import create_train_dataloader, create_val_dataloader, get_random_pro
|
|||
from .emb.qnt import decode_to_file
|
||||
from .metrics import wer, sim_o
|
||||
from .utils import setup_logging
|
||||
from .utils.io import json_read, json_write
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
|
@ -348,6 +349,7 @@ def main():
|
|||
language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
|
||||
prompt = dir / "prompt.wav"
|
||||
reference = dir / "reference.wav"
|
||||
metrics_path = dir / "metrics.json"
|
||||
out_path = dir / "out" / "ours.wav"
|
||||
out_path_comparison = dir / "out" / f"ours_{comparison_kwargs['suffix']}.wav"
|
||||
external_sources = [ dir / "out" / f"{source}.wav" for source in sources ]
|
||||
|
@ -374,15 +376,19 @@ def main():
|
|||
|
||||
# segregate comparisons into their own batch because they use different kwargs (and I do not support variadic-batched kwargs)
|
||||
if args.comparison:
|
||||
if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing):
|
||||
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
||||
|
||||
if should_generate:
|
||||
comparison_inputs.append((text, prompt, language, out_path_comparison))
|
||||
|
||||
metrics_inputs.append((text, language, out_path_comparison, reference))
|
||||
metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path))
|
||||
|
||||
if (args.skip_existing and not out_path.exists()) or not (args.skip_existing):
|
||||
should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
|
||||
|
||||
if should_generate:
|
||||
inputs.append((text, prompt, language, out_path))
|
||||
|
||||
metrics_inputs.append((text, language, out_path, reference))
|
||||
metrics_inputs.append((text, language, out_path, reference, metrics_path))
|
||||
|
||||
outputs.append((k, samples))
|
||||
|
||||
|
@ -393,10 +399,19 @@ def main():
|
|||
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
|
||||
|
||||
metrics_map = {}
|
||||
total_metrics = (0, 0)
|
||||
for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
|
||||
calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
|
||||
|
||||
if calculate:
|
||||
wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
|
||||
sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
|
||||
|
||||
metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
|
||||
json_write( metrics, metrics_path )
|
||||
else:
|
||||
metrics = json_read( metrics_path )
|
||||
wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
|
||||
|
||||
metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
|
||||
|
||||
# collate entries into HTML
|
||||
|
|
|
@ -23,7 +23,6 @@ except Exception as e:
|
|||
langdetect = None
|
||||
print(f'Error while importing langdetect: {str(e)}')
|
||||
|
||||
@cache
|
||||
def detect_language( text ):
|
||||
if langdetect is None:
|
||||
raise Exception('langdetect is not installed.')
|
||||
|
@ -34,7 +33,6 @@ def _get_graphs(path):
|
|||
graphs = f.read()
|
||||
return graphs
|
||||
|
||||
@cache
|
||||
def coerce_to_hiragana( runes, sep="" ):
|
||||
if pykakasi is None:
|
||||
raise Exception('pykakasi is not installed.')
|
||||
|
|
|
@ -19,7 +19,7 @@ from .config import cfg, Config
|
|||
from .models import get_models
|
||||
from .models.lora import enable_lora
|
||||
from .engines import load_engines, deepspeed_available
|
||||
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize, sentence_split
|
||||
from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
|
||||
from .models import download_model, DEFAULT_MODEL_PATH
|
||||
|
||||
if deepspeed_available:
|
||||
|
|
Loading…
Reference in New Issue
Block a user