store metrics and only recalculate them if the output file is newer than the metrics file
parent 0c69e798f7
commit 20b87bfbd0
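Note: the substance of this commit is in the demo's metrics loop further down. Each sample directory gets a metrics.json next to its generated audio, and WER/CER/SIM-O scores are only recomputed when that cache is missing or older than the audio. A minimal standalone sketch of the same pattern (cached_metrics and compute_metrics are hypothetical names; the local json_read/json_write stand in for the helpers imported from .utils.io):

from pathlib import Path
import json

def json_read(path):
    # stand-in for the repo's json_read helper
    return json.loads(Path(path).read_text())

def json_write(data, path):
    # stand-in for the repo's json_write helper
    Path(path).write_text(json.dumps(data))

def cached_metrics(out_path: Path, metrics_path: Path, compute_metrics):
    # recompute only if there is no cached file, or the audio was (re)generated after it
    stale = not metrics_path.exists() or metrics_path.stat().st_mtime < out_path.stat().st_mtime
    if stale:
        metrics = compute_metrics(out_path)   # e.g. {"wer": ..., "cer": ..., "sim-o": ...}
        json_write(metrics, metrics_path)     # store for later runs
    else:
        metrics = json_read(metrics_path)     # reuse the stored scores
    return metrics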
@@ -127,7 +127,7 @@ def get_random_prompts( validation=False, min_length=0, tokenized=False ):
 text_string = metadata["text"] if "text" in metadata else ""
 duration = metadata['duration'] if "duration" in metadata else 0
 else:
-_, metadata = _load_quants(path, return_metadata=True)
+_, metadata = _load_artifact(path, return_metadata=True)
 metadata = process_artifact_metadata( { "metadata": metadata } )
 text_string = metadata["text"] if "text" in metadata else ""
 duration = metadata['duration'] if "duration" in metadata else 0
@@ -564,17 +564,14 @@ def _replace_file_extension(path, suffix):
 path = Path(path)
 return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
-def _get_quant_extension():
+def _get_artifact_extension():
 return ".dac" if cfg.audio_backend == "dac" else ".enc"
 
-def _get_phone_extension():
-return ".json" # if cfg.audio_backend == "dac" else ".phn.txt"
+def _get_metadata_extension():
+return ".json"
 
-def _get_quant_path(path):
-return _replace_file_extension(path, _get_quant_extension())
+def _get_artifact_path(path):
+return _replace_file_extension(path, _get_artifact_extension())
 
-def _get_phone_path(path):
-return _replace_file_extension(path, _get_phone_extension())
-
 _durations_map = {}
 def _get_duration_map( type="training" ):
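Note: the helper renames in the hunk above (quant → artifact, phone → metadata) are purely cosmetic; paths are still derived by swapping the file extension. A self-contained illustration of _replace_file_extension with made-up paths (the .enc extension assumes a non-"dac" audio backend):

from pathlib import Path

def _replace_file_extension(path, suffix):
    # same logic as the function above: keep everything before the first "." and append the new suffix
    path = Path(path)
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

# hypothetical utterance; the artifact and metadata files sit next to the source audio
print(_replace_file_extension("speaker_a/utterance_0001.wav", ".enc"))   # speaker_a/utterance_0001.enc
print(_replace_file_extension("speaker_a/utterance_0001.wav", ".json"))  # speaker_a/utterance_0001.json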
@@ -627,7 +624,7 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
 metadata = json_read( metadata_path )
 
 if len(metadata) == 0:
-return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_quant_extension(), validate )
+return _fn( data_dir, type if cfg.dataset.use_hdf5 else _get_artifact_extension(), validate )
 
 def _validate( id, entry ):
 phones = entry['phones'] if "phones" in entry else 0
@@ -671,37 +668,18 @@ def _get_hdf5_paths( data_dir, type="training", validate=False ):
 
 return [ Path(f"{key}/{id}") for id, entry in cfg.hdf5[key].items() if _validate(id, entry) ] if key in cfg.hdf5 else []
 
-def _get_paths_of_extensions( path, extensions=_get_quant_extension(), validate=False ):
+def _get_paths_of_extensions( path, extensions=_get_artifact_extension(), validate=False ):
 if isinstance(path, str):
 path = Path(path)
 
 return [ p for p in list(path.iterdir()) ] if path.exists() and path.is_dir() else []
 
-def _load_quants(path, return_metadata=False) -> Tensor:
-qnt = np.load(_get_quant_path(path), allow_pickle=True)[()]
+def _load_artifact(path, return_metadata=False) -> Tensor:
+qnt = np.load(_get_artifact_path(path), allow_pickle=True)[()]
 if return_metadata:
 return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16), qnt["metadata"]
 return torch.from_numpy(qnt["codes"].astype(int))[0][:, :].t().to(torch.int16)
 
-# prune consecutive spaces
-def _cleanup_phones( phones, targets=[" "]):
-return [ p for i, p in enumerate(phones) if p not in targets or ( p in targets and p != phones[i-1] ) ]
-
-@cache
-def _get_phones(path):
-phone_path = _get_phone_path(path)
-quant_path = _get_quant_path(path)
-if phone_path.exists():
-#metadata = json.loads(open(phone_path, "r", encoding="utf-8").read())
-metadata = json_read(phone_path)
-elif quant_path.exists():
-_, metadata = _load_quants( path, return_metadata=True )
-else:
-raise Exception(f"Could not load phonemes: {path}")
-
-content = metadata["phonemes"]
-return "".join(content)
-
 def _interleaved_reorder(l, fn):
 groups = defaultdict(list)
 for e in l:
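Note: with _cleanup_phones and _get_phones removed in the hunk above, there is no separate phoneme-file lookup any more; the hunks that follow pull the transcript straight out of the artifact's embedded metadata. A hedged sketch of that pattern, using the module's own _load_artifact (the path below is made up):

# illustrative only; relies on _load_artifact defined above
resps, metadata = _load_artifact("speaker_a/utterance_0001", return_metadata=True)
phonemes = metadata["phonemes"]                                   # transcript is embedded in the artifact
lang = metadata["language"] if "language" in metadata else None   # as read in the Dataset hunks below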
@@ -991,7 +969,7 @@ class Dataset(_Dataset):
 key = _get_hdf5_path(path)
 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 else:
-qnt = _load_quants(path, return_metadata=False)
+qnt = _load_artifact(path, return_metadata=False)
 return qnt
 
 def sample_speakers(self, ignore=[]):
@@ -1026,7 +1004,7 @@ class Dataset(_Dataset):
 tone = metadata["tone"] if "tone" in metadata else None
 """
 else:
-resps, metadata = _load_quants(path, return_metadata=True)
+resps, metadata = _load_artifact(path, return_metadata=True)
 text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 
 """
@@ -1112,7 +1090,7 @@ class Dataset(_Dataset):
 key = _get_hdf5_path(path)
 qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 else:
-qnt = _load_quants(path, return_metadata=False)
+qnt = _load_artifact(path, return_metadata=False)
 
 if 0 < trim_length and trim_length < qnt.shape[0]:
 qnt = trim( qnt, trim_length, reencode=cfg.dataset.reencode_on_concat, device=cfg.dataset.reencode_device )
@@ -1184,7 +1162,7 @@ class Dataset(_Dataset):
 if cfg.dataset.retokenize_text and "phonemes" in metadata:
 text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 else:
-resps, metadata = _load_quants(path, return_metadata=True)
+resps, metadata = _load_artifact(path, return_metadata=True)
 text = torch.tensor(tokenize( metadata["phonemes"] )).to(self.text_dtype)
 
 lang = metadata["language"] if "language" in metadata else None
@@ -1613,13 +1591,13 @@ def create_dataset_metadata( skip_existing=False ):
 files = os.listdir(f'{root}/{name}/')
 
 # grab IDs for every file
-ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
+ids = { file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files }
 
 wrote = False
 
 for id in tqdm(ids, desc=f"Processing {name}", disable=True):
 try:
-quant_path = Path(f'{root}/{name}/{id}{_get_quant_extension()}')
+quant_path = Path(f'{root}/{name}/{id}{_get_artifact_extension()}')
 
 if audios and not quant_path.exists():
 continue
@@ -1696,7 +1674,7 @@ def create_dataset_hdf5( skip_existing=True ):
 files = os.listdir(f'{root}/{name}/')
 
 # grab IDs for every file
-ids = { file.replace(_get_quant_extension(), "").replace(_get_phone_extension(), "") for file in files }
+ids = { file.replace(_get_artifact_extension(), "").replace(_get_metadata_extension(), "") for file in files }
 
 """
 # rephonemizes if you fuck up and use and old tokenizer...
@@ -1724,8 +1702,8 @@ def create_dataset_hdf5( skip_existing=True ):
 
 for id in tqdm(ids, desc=f"Processing {name}", disable=not verbose):
 try:
-quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_quant_extension()}') if audios else True
-text_exists = os.path.exists(f'{root}/{name}/{id}{_get_phone_extension()}') if texts else True
+quant_exists = os.path.exists(f'{root}/{name}/{id}{_get_artifact_extension()}') if audios else True
+text_exists = os.path.exists(f'{root}/{name}/{id}{_get_metadata_extension()}') if texts else True
 
 if not quant_exists:
 continue
@@ -1744,7 +1722,7 @@ def create_dataset_hdf5( skip_existing=True ):
 
 # audio
 if audios:
-artifact = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()]
+artifact = np.load(f'{root}/{name}/{id}{_get_artifact_extension()}', allow_pickle=True)[()]
 qnt = torch.from_numpy(artifact["codes"].astype(int))[0].t().to(dtype=torch.int16)
 
 utterance_metadata = process_artifact_metadata( artifact )
@@ -1757,7 +1735,7 @@ def create_dataset_hdf5( skip_existing=True ):
 # to-do: ensure I can remove this block
 if texts:
 if not utterance_metadata and text_exists:
-utterance_metadata = json_read(f'{root}/{name}/{id}{_get_phone_extension()}')
+utterance_metadata = json_read(f'{root}/{name}/{id}{_get_metadata_extension()}')
 
 phn = "".join(utterance_metadata["phonemes"])
 phn = cfg.tokenizer.encode(phn)
@@ -1883,7 +1861,7 @@ if __name__ == "__main__":
 continue
 metadata = { f'{k}': f'{v}' for k, v in cfg.hdf5[key].attrs.items() }
 else:
-_, metadata = _load_quants(path, return_metadata=True)
+_, metadata = _load_artifact(path, return_metadata=True)
 
 phonemes = metadata["phonemes"]
 
@@ -32,6 +32,7 @@ from .data import create_train_dataloader, create_val_dataloader, get_random_prompts
 from .emb.qnt import decode_to_file
 from .metrics import wer, sim_o
 from .utils import setup_logging
+from .utils.io import json_read, json_write
 
 from tqdm import tqdm, trange
 
@@ -348,6 +349,7 @@ def main():
 language = open(dir / "language.txt").read() if (dir / "language.txt").exists() else "en"
 prompt = dir / "prompt.wav"
 reference = dir / "reference.wav"
+metrics_path = dir / "metrics.json"
 out_path = dir / "out" / "ours.wav"
 out_path_comparison = dir / "out" / f"ours_{comparison_kwargs['suffix']}.wav"
 external_sources = [ dir / "out" / f"{source}.wav" for source in sources ]
@@ -374,15 +376,19 @@ def main():
 
 # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
 if args.comparison:
-if (args.skip_existing and not out_path_comparison.exists()) or not (args.skip_existing):
+should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
+
+if should_generate:
 comparison_inputs.append((text, prompt, language, out_path_comparison))
 
-metrics_inputs.append((text, language, out_path_comparison, reference))
+metrics_inputs.append((text, language, out_path_comparison, reference, metrics_path))
 
-if (args.skip_existing and not out_path.exists()) or not (args.skip_existing):
+should_generate = (args.skip_existing and not out_path.exists()) or not (args.skip_existing)
+
+if should_generate:
 inputs.append((text, prompt, language, out_path))
 
-metrics_inputs.append((text, language, out_path, reference))
+metrics_inputs.append((text, language, out_path, reference, metrics_path))
 
 outputs.append((k, samples))
 
@@ -393,10 +399,19 @@ def main():
 process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
 
 metrics_map = {}
-total_metrics = (0, 0)
-for text, language, out_path, reference_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+for text, language, out_path, reference_path, metrics_path in tqdm(metrics_inputs, desc="Calculating metrics"):
+calculate = not metrics_path.exists() or (metrics_path.stat().st_mtime < out_path.stat().st_mtime)
+
+if calculate:
 wer_score, cer_score = wer( out_path, text, language=language, device=tts.device, dtype=tts.dtype, model_name=args.transcription_model )
 sim_o_score = sim_o( out_path, reference_path, device=tts.device, dtype=tts.dtype, model_name=args.speaker_similarity_model )
+
+metrics = {"wer": wer_score, "cer": cer_score, "sim-o": sim_o_score}
+json_write( metrics, metrics_path )
+else:
+metrics = json_read( metrics_path )
+wer_score, cer_score, sim_o_score = metrics["wer"], metrics["cer"], metrics["sim-o"]
+
 metrics_map[out_path] = (wer_score, cer_score, sim_o_score)
 
 # collate entries into HTML
@@ -23,7 +23,6 @@ except Exception as e:
 langdetect = None
 print(f'Error while importing langdetect: {str(e)}')
 
-@cache
 def detect_language( text ):
 if langdetect is None:
 raise Exception('langdetect is not installed.')
@@ -34,7 +33,6 @@ def _get_graphs(path):
 graphs = f.read()
 return graphs
 
-@cache
 def coerce_to_hiragana( runes, sep="" ):
 if pykakasi is None:
 raise Exception('pykakasi is not installed.')
@@ -19,7 +19,7 @@ from .config import cfg, Config
 from .models import get_models
 from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize, sentence_split
+from .data import get_phone_symmap, get_lang_symmap, tokenize, sentence_split
 from .models import download_model, DEFAULT_MODEL_PATH
 
 if deepspeed_available: